]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
node.py: refactor config-file getting and setting
authorBrian Warner <warner@lothar.com>
Tue, 28 Aug 2007 01:58:39 +0000 (18:58 -0700)
committerBrian Warner <warner@lothar.com>
Tue, 28 Aug 2007 01:58:39 +0000 (18:58 -0700)
src/allmydata/client.py
src/allmydata/introducer_and_vdrive.py
src/allmydata/node.py

index c2d74e46faf8a7ced30c3462614c8e4179eaa929..f279df108fd111d2a38b81fb9c301a248df85de6 100644 (file)
@@ -24,13 +24,7 @@ class Client(node.Node, Referenceable):
     PORTNUMFILE = "client.port"
     STOREDIR = 'storage'
     NODETYPE = "client"
-    WEBPORTFILE = "webport"
-    WEB_ALLOW_LOCAL_ACCESS_FILE = "webport_allow_localfile"
-    INTRODUCER_FURL_FILE = "introducer.furl"
-    MY_FURL_FILE = "myself.furl"
     SUICIDE_PREVENTION_HOTLINE_FILE = "suicide_prevention_hotline"
-    SIZELIMIT_FILE = "sizelimit"
-    PUSH_TO_OURSELVES_FILE = "push_to_ourselves"
 
     # we're pretty narrow-minded right now
     OLDEST_SUPPORTED_VERSION = allmydata.__version__
@@ -45,17 +39,11 @@ class Client(node.Node, Referenceable):
         self.add_service(Uploader())
         self.add_service(Downloader())
         self.add_service(VirtualDrive())
-        try:
-            webport = open(os.path.join(self.basedir, self.WEBPORTFILE),
-                           "r").read().strip() # strports string
-        except EnvironmentError:
-            pass # absent or unreadable webport file
-        else:
-            self.init_web(webport)
-
-        INTRODUCER_FURL_FILE = os.path.join(self.basedir,
-                                            self.INTRODUCER_FURL_FILE)
-        self.introducer_furl = open(INTRODUCER_FURL_FILE, "r").read().strip()
+        webport = self.get_config("webport")
+        if webport:
+            self.init_web(webport) # strports string
+
+        self.introducer_furl = self.get_config("introducer.furl", required=True)
 
         hotline_file = os.path.join(self.basedir,
                                     self.SUICIDE_PREVENTION_HOTLINE_FILE)
@@ -68,12 +56,8 @@ class Client(node.Node, Referenceable):
         storedir = os.path.join(self.basedir, self.STOREDIR)
         sizelimit = None
 
-        try:
-            data = open(os.path.join(self.basedir, self.SIZELIMIT_FILE),
-                        "r").read().strip()
-        except EnvironmentError:
-            pass # absent or unreadable sizelimit file
-        else:
+        data = self.get_config("sizelimit")
+        if data:
             m = re.match(r"^(\d+)([kKmMgG]?[bB]?)$", data)
             if not m:
                 log.msg("SIZELIMIT_FILE contains unparseable value %s" % data)
@@ -88,21 +72,19 @@ class Client(node.Node, Referenceable):
                               "G": 1000 * 1000 * 1000,
                               }[suffix]
                 sizelimit = int(number) * multiplier
-        NOSTORAGE_FILE = os.path.join(self.basedir, "debug_no_storage")
-        no_storage = os.path.exists(NOSTORAGE_FILE)
+        no_storage = self.get_config("debug_no_storage") is not None
         self.add_service(StorageServer(storedir, sizelimit, no_storage))
 
     def init_options(self):
         self.push_to_ourselves = None
-        filename = os.path.join(self.basedir, self.PUSH_TO_OURSELVES_FILE)
-        if os.path.exists(filename):
+        if self.get_config("push_to_ourselves") is not None:
             self.push_to_ourselves = True
 
     def init_web(self, webport):
         # this must be called after the VirtualDrive is attached
         ws = WebishServer(webport)
-        ws.allow_local_access(os.path.exists(os.path.join(self.basedir,
-                              self.WEB_ALLOW_LOCAL_ACCESS_FILE)))
+        if self.get_config("webport_allow_localfile") is not None:
+            ws.allow_local_access(True)
         self.add_service(ws)
         vd = self.getServiceNamed("vdrive")
         startfile = os.path.join(self.basedir, "start.html")
@@ -122,18 +104,13 @@ class Client(node.Node, Referenceable):
         self.log("tub_ready")
 
         my_old_name = None
-        try:
-            my_old_furl = open(os.path.join(self.basedir, self.MY_FURL_FILE),
-                               "r").read().strip()
-        except EnvironmentError:
-            pass # absent or unreadable myfurl file
-        else:
+        my_old_furl = self.get_config("myself.furl")
+        if my_old_furl is not None:
             sturdy = SturdyRef(my_old_furl)
             my_old_name = sturdy.name
 
         self.my_furl = self.tub.registerReference(self, my_old_name)
-        open(os.path.join(self.basedir, self.MY_FURL_FILE),
-             "w").write(self.my_furl + "\n")
+        self.write_config("myself.furl", self.my_furl + "\n")
 
         ic = IntroducerClient(self.tub, self.introducer_furl, self.my_furl)
         self.introducer_client = ic
index 27e83835183bed5c0d654de34e6d2fa66cc6df3e..8aa2e0aba6caf5ba73d8a2cedf548ab528d3b465 100644 (file)
@@ -22,9 +22,7 @@ class IntroducerAndVdrive(node.Node):
         r = self.add_service(i)
         self.urls["introducer"] = self.tub.registerReference(r, "introducer")
         self.log(" introducer is at %s" % self.urls["introducer"])
-        f = open(os.path.join(self.basedir, "introducer.furl"), "w")
-        f.write(self.urls["introducer"] + "\n")
-        f.close()
+        self.write_config("introducer.furl", self.urls["introducer"] + "\n")
 
         vdrive_dir = os.path.join(self.basedir, self.VDRIVEDIR)
         vds = self.add_service(VirtualDriveServer(vdrive_dir))
@@ -32,20 +30,15 @@ class IntroducerAndVdrive(node.Node):
         vds.set_furl(vds_furl)
         self.urls["vdrive"] = vds_furl
         self.log(" vdrive is at %s" % self.urls["vdrive"])
-        f = open(os.path.join(self.basedir, "vdrive.furl"), "w")
-        f.write(self.urls["vdrive"] + "\n")
-        f.close()
+        self.write_config("vdrive.furl", self.urls["vdrive"] + "\n")
 
         encoding_parameters = self.read_encoding_parameters()
         i.set_encoding_parameters(encoding_parameters)
 
     def read_encoding_parameters(self):
         k, desired, n = self.DEFAULT_K, self.DEFAULT_DESIRED, self.DEFAULT_N
-        PARAM_FILE = os.path.join(self.basedir, self.ENCODING_PARAMETERS_FILE)
-        if os.path.exists(PARAM_FILE):
-            f = open(PARAM_FILE, "r")
-            data = f.read().strip()
-            f.close()
+        data = self.get_config("encoding_parameters")
+        if data is not None:
             k,desired,n = data.split()
             k = int(k); desired = int(desired); n = int(n)
         return k, desired, n
index 38c0dd715b93a313743c4934cbd41ca8d8114743..279d8e8089fe79268f634b0d68e08b9a68b76632 100644 (file)
@@ -25,7 +25,6 @@ class Node(service.MultiService):
     PORTNUMFILE = None
     CERTFILE = "node.pem"
     LOCAL_IP_FILE = "advertised_ip_addresses"
-    NODEIDFILE = "my_nodeid"
 
     def __init__(self, basedir="."):
         service.MultiService.__init__(self)
@@ -36,9 +35,7 @@ class Node(service.MultiService):
         self.tub.setOption("logLocalFailures", True)
         self.tub.setOption("logRemoteFailures", True)
         self.nodeid = b32decode(self.tub.tubID.upper()) # binary format
-        f = open(os.path.join(self.basedir, self.NODEIDFILE), "w")
-        f.write(b32encode(self.nodeid).lower() + "\n")
-        f.close()
+        self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n")
         self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing
         assert self.PORTNUMFILE, "Your node.Node subclass must provide PORTNUMFILE"
         self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
@@ -68,6 +65,44 @@ class Node(service.MultiService):
                  % (allmydata.__version__, foolscap.__version__,
                     twisted.__version__, zfec.__version__,))
 
+    def get_config(self, name, mode="r", required=False):
+        """Get the (string) contents of a config file, or None if the file
+        did not exist. If required=True, raise an exception rather than
+        returning None. Any leading or trailing whitespace will be stripped
+        from the data."""
+        fn = os.path.join(self.basedir, name)
+        try:
+            return open(fn, mode).read().strip()
+        except EnvironmentError:
+            if not required:
+                return None
+            raise
+
+    def get_or_create_config(self, name, default, mode="w"):
+        """Try to get the (string) contents of a config file. If the file
+        does not exist, create it with the given default value, and return
+        the default value. Any leading or trailing whitespace will be
+        stripped from the data."""
+        value = self.get_config(name)
+        if value is None:
+            value = default
+            fn = os.path.join(self.basedir, name)
+            try:
+                open(fn, mode).write(value)
+            except EnvironmentError, e:
+                self.log("Unable to write config file '%s'" % fn)
+                self.log(e)
+        return value
+
+    def write_config(self, name, value, mode="w"):
+        """Write a string to a config file."""
+        fn = os.path.join(self.basedir, name)
+        try:
+            open(fn, mode).write(value)
+        except EnvironmentError, e:
+            self.log("Unable to write config file '%s'" % fn)
+            self.log(e)
+
     def get_versions(self):
         return {'allmydata': allmydata.__version__,
                 'foolscap': foolscap.__version__,