]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
Refactoring to make node config accessible without actually creating a Node. refs...
authorDaira Hopwood <daira@jacaranda.org>
Fri, 17 Apr 2015 21:27:19 +0000 (22:27 +0100)
committerDaira Hopwood <daira@jacaranda.org>
Fri, 17 Apr 2015 21:31:40 +0000 (22:31 +0100)
Signed-off-by: Daira Hopwood <daira@jacaranda.org>
src/allmydata/node.py

index 8d1ca15e233ae35070043f0b363f9451f859b530..19e28ad1e084ecea1914adb56609e6556a363729 100644 (file)
@@ -67,50 +67,7 @@ class UnescapedHashError(Exception):
                 % quote_output("[%s]%s = %s" % self.args))
 
 
-class Node(service.MultiService):
-    # this implements common functionality of both Client nodes and Introducer
-    # nodes.
-    NODETYPE = "unknown NODETYPE"
-    PORTNUMFILE = None
-    CERTFILE = "node.pem"
-    GENERATED_FILES = []
-
-    def __init__(self, basedir=u"."):
-        service.MultiService.__init__(self)
-        self.basedir = abspath_expanduser_unicode(unicode(basedir))
-        self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
-        self._tub_ready_observerlist = observer.OneShotObserverList()
-        fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700)
-        fileutil.write(os.path.join(self.basedir, "private", "README"), PRIV_README, mode="")
-
-        # creates self.config
-        self.read_config()
-        nickname_utf8 = self.get_config("node", "nickname", "<unspecified>")
-        self.nickname = nickname_utf8.decode("utf-8")
-        assert type(self.nickname) is unicode
-
-        self.init_tempdir()
-        self.create_tub()
-        self.logSource="Node"
-
-        self.setup_ssh()
-        self.setup_logging()
-        self.log("Node constructed. " + get_package_versions_string())
-        iputil.increase_rlimits()
-
-    def init_tempdir(self):
-        tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8')
-        tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir)
-        if not os.path.exists(tempdir):
-            fileutil.make_dirs(tempdir)
-        tempfile.tempdir = tempdir
-        # this should cause twisted.web.http (which uses
-        # tempfile.TemporaryFile) to put large request bodies in the given
-        # directory. Without this, the default temp dir is usually /tmp/,
-        # which is frequently too small.
-        test_name = tempfile.mktemp()
-        _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir)
-
+class ConfigMixin:
     @staticmethod
     def _contains_unescaped_hash(item):
         characters = iter(item)
@@ -174,18 +131,6 @@ class Node(service.MultiService):
             if os.path.exists(tahoe_cfg):
                 raise
 
-        cfg_tubport = self.get_config("node", "tub.port", "")
-        if not cfg_tubport:
-            # For 'tub.port', tahoe.cfg overrides the individual file on
-            # disk. So only read self._portnumfile if tahoe.cfg doesn't
-            # provide a value.
-            try:
-                file_tubport = fileutil.read(self._portnumfile).strip()
-                self.set_config("node", "tub.port", file_tubport)
-            except EnvironmentError:
-                if os.path.exists(self._portnumfile):
-                    raise
-
     def error_about_old_config_files(self):
         """ If any old configuration files are detected, raise OldConfigError. """
 
@@ -205,47 +150,6 @@ class Node(service.MultiService):
             twlog.msg(e)
             raise e
 
-    def create_tub(self):
-        certfile = os.path.join(self.basedir, "private", self.CERTFILE)
-        self.tub = Tub(certFile=certfile)
-        self.tub.setOption("logLocalFailures", True)
-        self.tub.setOption("logRemoteFailures", True)
-        self.tub.setOption("expose-remote-exception-types", False)
-
-        # see #521 for a discussion of how to pick these timeout values.
-        keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "")
-        if keepalive_timeout_s:
-            self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s))
-        disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "")
-        if disconnect_timeout_s:
-            # N.B.: this is in seconds, so use "1800" to get 30min
-            self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s))
-
-        self.nodeid = b32decode(self.tub.tubID.upper()) # binary format
-        self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n")
-        self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing
-
-        tubport = self.get_config("node", "tub.port", "tcp:0")
-        self.tub.listenOn(tubport)
-        # we must wait until our service has started before we can find out
-        # our IP address and thus do tub.setLocation, and we can't register
-        # any services with the Tub until after that point
-        self.tub.setServiceParent(self)
-
-    def setup_ssh(self):
-        ssh_port = self.get_config("node", "ssh.port", "")
-        if ssh_port:
-            ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
-            ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir)
-            from allmydata import manhole
-            m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile)
-            m.setServiceParent(self)
-            self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,))
-
-    def get_app_versions(self):
-        # TODO: merge this with allmydata.get_package_versions
-        return dict(app_versions.versions)
-
     def get_optional_config_from_file(self, path):
         """Read the (string) contents of a file. Any leading or trailing
         whitespace will be stripped from the data. If the file does not exist,
@@ -317,6 +221,113 @@ class Node(service.MultiService):
             self.log("Unable to write config file '%s'" % fn)
             self.log(e)
 
+
+class ConfigOnly(object, ConfigMixin):
+    GENERATED_FILES = []
+
+    def __init__(self, basedir=u"."):
+        self.basedir = abspath_expanduser_unicode(unicode(basedir))
+        self.read_config()
+
+
+class Node(service.MultiService, ConfigMixin):
+    # this implements common functionality of both Client nodes and Introducer
+    # nodes.
+    NODETYPE = "unknown NODETYPE"
+    PORTNUMFILE = None
+    CERTFILE = "node.pem"
+    GENERATED_FILES = []
+
+    def __init__(self, basedir=u"."):
+        service.MultiService.__init__(self)
+        self.basedir = abspath_expanduser_unicode(unicode(basedir))
+        self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
+        self._tub_ready_observerlist = observer.OneShotObserverList()
+        fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700)
+        fileutil.write(os.path.join(self.basedir, "private", "README"), PRIV_README, mode="")
+
+        # creates self.config
+        self.read_config()
+
+        cfg_tubport = self.get_config("node", "tub.port", "")
+        if not cfg_tubport:
+            # For 'tub.port', tahoe.cfg overrides the individual file on
+            # disk. So only read self._portnumfile if tahoe.cfg doesn't
+            # provide a value.
+            try:
+                file_tubport = fileutil.read(self._portnumfile).strip()
+                self.set_config("node", "tub.port", file_tubport)
+            except EnvironmentError:
+                if os.path.exists(self._portnumfile):
+                    raise
+
+        nickname_utf8 = self.get_config("node", "nickname", "<unspecified>")
+        self.nickname = nickname_utf8.decode("utf-8")
+        assert type(self.nickname) is unicode
+
+        self.init_tempdir()
+        self.create_tub()
+        self.logSource="Node"
+
+        self.setup_ssh()
+        self.setup_logging()
+        self.log("Node constructed. " + get_package_versions_string())
+        iputil.increase_rlimits()
+
+    def init_tempdir(self):
+        tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8')
+        tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir)
+        if not os.path.exists(tempdir):
+            fileutil.make_dirs(tempdir)
+        tempfile.tempdir = tempdir
+        # this should cause twisted.web.http (which uses
+        # tempfile.TemporaryFile) to put large request bodies in the given
+        # directory. Without this, the default temp dir is usually /tmp/,
+        # which is frequently too small.
+        test_name = tempfile.mktemp()
+        _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir)
+
+    def create_tub(self):
+        certfile = os.path.join(self.basedir, "private", self.CERTFILE)
+        self.tub = Tub(certFile=certfile)
+        self.tub.setOption("logLocalFailures", True)
+        self.tub.setOption("logRemoteFailures", True)
+        self.tub.setOption("expose-remote-exception-types", False)
+
+        # see #521 for a discussion of how to pick these timeout values.
+        keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "")
+        if keepalive_timeout_s:
+            self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s))
+        disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "")
+        if disconnect_timeout_s:
+            # N.B.: this is in seconds, so use "1800" to get 30min
+            self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s))
+
+        self.nodeid = b32decode(self.tub.tubID.upper()) # binary format
+        self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n")
+        self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing
+
+        tubport = self.get_config("node", "tub.port", "tcp:0")
+        self.tub.listenOn(tubport)
+        # we must wait until our service has started before we can find out
+        # our IP address and thus do tub.setLocation, and we can't register
+        # any services with the Tub until after that point
+        self.tub.setServiceParent(self)
+
+    def setup_ssh(self):
+        ssh_port = self.get_config("node", "ssh.port", "")
+        if ssh_port:
+            ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
+            ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir)
+            from allmydata import manhole
+            m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile)
+            m.setServiceParent(self)
+            self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,))
+
+    def get_app_versions(self):
+        # TODO: merge this with allmydata.get_package_versions
+        return dict(app_versions.versions)
+
     def startService(self):
         # Note: this class can be started and stopped at most once.
         self.log("Node.startService")