From: Daira Hopwood Date: Tue, 4 Aug 2015 18:10:24 +0000 (+0100) Subject: Refactoring to make node config accessible without actually creating a Node. refs... X-Git-Url: https://git.rkrishnan.org/somewhere?a=commitdiff_plain;h=8fb27eadb46010d7e121b8216ba3765709334a25;p=tahoe-lafs%2Ftahoe-lafs.git Refactoring to make node config accessible without actually creating a Node. refs #1971 Signed-off-by: Daira Hopwood --- diff --git a/src/allmydata/node.py b/src/allmydata/node.py index 8d1ca15e..0bafc2b3 100644 --- a/src/allmydata/node.py +++ b/src/allmydata/node.py @@ -67,61 +67,7 @@ class UnescapedHashError(Exception): % quote_output("[%s]%s = %s" % self.args)) -class Node(service.MultiService): - # this implements common functionality of both Client nodes and Introducer - # nodes. - NODETYPE = "unknown NODETYPE" - PORTNUMFILE = None - CERTFILE = "node.pem" - GENERATED_FILES = [] - - def __init__(self, basedir=u"."): - service.MultiService.__init__(self) - self.basedir = abspath_expanduser_unicode(unicode(basedir)) - self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE) - self._tub_ready_observerlist = observer.OneShotObserverList() - fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700) - fileutil.write(os.path.join(self.basedir, "private", "README"), PRIV_README, mode="") - - # creates self.config - self.read_config() - nickname_utf8 = self.get_config("node", "nickname", "") - self.nickname = nickname_utf8.decode("utf-8") - assert type(self.nickname) is unicode - - self.init_tempdir() - self.create_tub() - self.logSource="Node" - - self.setup_ssh() - self.setup_logging() - self.log("Node constructed. " + get_package_versions_string()) - iputil.increase_rlimits() - - def init_tempdir(self): - tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8') - tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir) - if not os.path.exists(tempdir): - fileutil.make_dirs(tempdir) - tempfile.tempdir = tempdir - # this should cause twisted.web.http (which uses - # tempfile.TemporaryFile) to put large request bodies in the given - # directory. Without this, the default temp dir is usually /tmp/, - # which is frequently too small. - test_name = tempfile.mktemp() - _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir) - - @staticmethod - def _contains_unescaped_hash(item): - characters = iter(item) - for c in characters: - if c == '\\': - characters.next() - elif c == '#': - return True - - return False - +class ConfigMixin: def get_config(self, section, option, default=_None, boolean=False): try: if boolean: @@ -174,18 +120,6 @@ class Node(service.MultiService): if os.path.exists(tahoe_cfg): raise - cfg_tubport = self.get_config("node", "tub.port", "") - if not cfg_tubport: - # For 'tub.port', tahoe.cfg overrides the individual file on - # disk. So only read self._portnumfile if tahoe.cfg doesn't - # provide a value. - try: - file_tubport = fileutil.read(self._portnumfile).strip() - self.set_config("node", "tub.port", file_tubport) - except EnvironmentError: - if os.path.exists(self._portnumfile): - raise - def error_about_old_config_files(self): """ If any old configuration files are detected, raise OldConfigError. """ @@ -205,47 +139,6 @@ class Node(service.MultiService): twlog.msg(e) raise e - def create_tub(self): - certfile = os.path.join(self.basedir, "private", self.CERTFILE) - self.tub = Tub(certFile=certfile) - self.tub.setOption("logLocalFailures", True) - self.tub.setOption("logRemoteFailures", True) - self.tub.setOption("expose-remote-exception-types", False) - - # see #521 for a discussion of how to pick these timeout values. - keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "") - if keepalive_timeout_s: - self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s)) - disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "") - if disconnect_timeout_s: - # N.B.: this is in seconds, so use "1800" to get 30min - self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s)) - - self.nodeid = b32decode(self.tub.tubID.upper()) # binary format - self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n") - self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing - - tubport = self.get_config("node", "tub.port", "tcp:0") - self.tub.listenOn(tubport) - # we must wait until our service has started before we can find out - # our IP address and thus do tub.setLocation, and we can't register - # any services with the Tub until after that point - self.tub.setServiceParent(self) - - def setup_ssh(self): - ssh_port = self.get_config("node", "ssh.port", "") - if ssh_port: - ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8') - ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir) - from allmydata import manhole - m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile) - m.setServiceParent(self) - self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,)) - - def get_app_versions(self): - # TODO: merge this with allmydata.get_package_versions - return dict(app_versions.versions) - def get_optional_config_from_file(self, path): """Read the (string) contents of a file. Any leading or trailing whitespace will be stripped from the data. If the file does not exist, @@ -317,6 +210,124 @@ class Node(service.MultiService): self.log("Unable to write config file '%s'" % fn) self.log(e) + +class ConfigOnly(object, ConfigMixin): + GENERATED_FILES = [] + + def __init__(self, basedir=u"."): + self.basedir = abspath_expanduser_unicode(unicode(basedir)) + self.read_config() + + +class Node(service.MultiService, ConfigMixin): + # this implements common functionality of both Client nodes and Introducer + # nodes. + NODETYPE = "unknown NODETYPE" + PORTNUMFILE = None + CERTFILE = "node.pem" + GENERATED_FILES = [] + + def __init__(self, basedir=u"."): + service.MultiService.__init__(self) + self.basedir = abspath_expanduser_unicode(unicode(basedir)) + self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE) + self._tub_ready_observerlist = observer.OneShotObserverList() + fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700) + fileutil.write(os.path.join(self.basedir, "private", "README"), PRIV_README, mode="") + + # creates self.config + self.read_config() + + cfg_tubport = self.get_config("node", "tub.port", "") + if not cfg_tubport: + # For 'tub.port', tahoe.cfg overrides the individual file on + # disk. So only read self._portnumfile if tahoe.cfg doesn't + # provide a value. + try: + file_tubport = fileutil.read(self._portnumfile).strip() + self.set_config("node", "tub.port", file_tubport) + except EnvironmentError: + if os.path.exists(self._portnumfile): + raise + + nickname_utf8 = self.get_config("node", "nickname", "") + self.nickname = nickname_utf8.decode("utf-8") + assert type(self.nickname) is unicode + + self.init_tempdir() + self.create_tub() + self.logSource="Node" + + self.setup_ssh() + self.setup_logging() + self.log("Node constructed. " + get_package_versions_string()) + iputil.increase_rlimits() + + def init_tempdir(self): + tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8') + tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir) + if not os.path.exists(tempdir): + fileutil.make_dirs(tempdir) + tempfile.tempdir = tempdir + # this should cause twisted.web.http (which uses + # tempfile.TemporaryFile) to put large request bodies in the given + # directory. Without this, the default temp dir is usually /tmp/, + # which is frequently too small. + test_name = tempfile.mktemp() + _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir) + + @staticmethod + def _contains_unescaped_hash(item): + characters = iter(item) + for c in characters: + if c == '\\': + characters.next() + elif c == '#': + return True + + return False + + def create_tub(self): + certfile = os.path.join(self.basedir, "private", self.CERTFILE) + self.tub = Tub(certFile=certfile) + self.tub.setOption("logLocalFailures", True) + self.tub.setOption("logRemoteFailures", True) + self.tub.setOption("expose-remote-exception-types", False) + + # see #521 for a discussion of how to pick these timeout values. + keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "") + if keepalive_timeout_s: + self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s)) + disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "") + if disconnect_timeout_s: + # N.B.: this is in seconds, so use "1800" to get 30min + self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s)) + + self.nodeid = b32decode(self.tub.tubID.upper()) # binary format + self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n") + self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing + + tubport = self.get_config("node", "tub.port", "tcp:0") + self.tub.listenOn(tubport) + # we must wait until our service has started before we can find out + # our IP address and thus do tub.setLocation, and we can't register + # any services with the Tub until after that point + self.tub.setServiceParent(self) + + def setup_ssh(self): + ssh_port = self.get_config("node", "ssh.port", "") + if ssh_port: + ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8') + ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir) + from allmydata import manhole + m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile) + m.setServiceParent(self) + self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,)) + + def get_app_versions(self): + # TODO: merge this with allmydata.get_package_versions + return dict(app_versions.versions) + def startService(self): # Note: this class can be started and stopped at most once. self.log("Node.startService")