]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
rearrange service startup a bit, now Node.startService() returns a Deferred that...
authorBrian Warner <warner@lothar.com>
Thu, 8 Mar 2007 22:10:36 +0000 (15:10 -0700)
committerBrian Warner <warner@lothar.com>
Thu, 8 Mar 2007 22:10:36 +0000 (15:10 -0700)
src/allmydata/client.py
src/allmydata/node.py
src/allmydata/test/test_client.py
src/allmydata/test/test_queen.py
src/allmydata/test/test_storage.py
src/allmydata/test/test_system.py

index 53956915a89734c91ce2097b6a9c58a906cd5706..f0b3927b981c8dfc02f970180905b932c595a969 100644 (file)
@@ -29,6 +29,7 @@ class Client(node.Node, Referenceable):
     def __init__(self, basedir="."):
         node.Node.__init__(self, basedir)
         self.queen = None # self.queen is either None or a RemoteReference
+        self.my_pburl = None
         self.all_peers = set()
         self.peer_pburls = {}
         self.connections = {}
index db45ac9324988bdcee00ac4b037e3dfd2380e1dc..da656a1900ccc12c1c3e839541a9dcc70663e726 100644 (file)
@@ -1,10 +1,11 @@
 
-from twisted.application import service
 import os.path
+from twisted.python import log
+from twisted.application import service
+from twisted.internet import defer
 from foolscap import Tub
 from allmydata.util.iputil import get_local_addresses
-from allmydata.util import idlib
-from twisted.python import log
+from allmydata.util import idlib, observer
 
 class Node(service.MultiService):
     # this implements common functionality of both Client nodes and the Queen
@@ -18,6 +19,7 @@ class Node(service.MultiService):
     def __init__(self, basedir="."):
         service.MultiService.__init__(self)
         self.basedir = os.path.abspath(basedir)
+        self._tub_ready_observerlist = observer.OneShotObserverList()
         assert self.CERTFILE, "Your node.Node subclass must provide CERTFILE"
         certfile = os.path.join(self.basedir, self.CERTFILE)
         if os.path.exists(certfile):
@@ -55,6 +57,37 @@ class Node(service.MultiService):
                 m.setServiceParent(self)
                 self.log("AuthorizedKeysManhole listening on %d" % portnum)
 
+    def startService(self):
+        """Start the node. Returns a Deferred that fires (with self) when it
+        is ready to go.
+
+        Many callers don't pay attention to the return value from
+        startService, since they aren't going to do anything special when it
+        finishes. If they are (for example unit tests which need to wait for
+        the node to fully start up before it gets shut down), they can wait
+        for the Deferred I return to fire. In particular, you should wait for
+        my startService() Deferred to fire before you call my stopService()
+        method.
+        """
+
+        # note: this class can only be started and stopped once.
+        service.MultiService.startService(self)
+        d = defer.succeed(None)
+        d.addCallback(lambda res: get_local_addresses())
+        d.addCallback(self._setup_tub)
+        d.addCallback(lambda res: self.tub_ready())
+        def _ready(res):
+            self.log("%s running" % self.NODETYPE)
+            self._tub_ready_observerlist.fire(self)
+            return self
+        d.addCallback(_ready)
+        return d
+
+    def shutdown(self):
+        """Shut down the node. Returns a Deferred that fires (with None) when
+        it finally stops kicking."""
+        return self.stopService()
+
     def log(self, msg):
         log.msg(self.short_nodeid + ": " + msg)
 
@@ -86,15 +119,10 @@ class Node(service.MultiService):
         # called when the Tub is available for registerReference
         pass
 
+    def when_tub_ready(self):
+        return self._tub_ready_observerlist.when_fired()
+
     def add_service(self, s):
         s.setServiceParent(self)
         return s
 
-    def startService(self):
-        # note: this class can only be started and stopped once.
-        service.MultiService.startService(self)
-        local_addresses = get_local_addresses()
-        self._setup_tub(local_addresses)
-        self.tub_ready()
-        self.log("%s running" % self.NODETYPE)
-
index 5cd01e06ad75f47c8f6111cfea62067f086dc6a2..eee9dbc94c11ccbf2916f4d7642a2577f6c676b7 100644 (file)
@@ -6,8 +6,9 @@ from allmydata import client
 class Basic(unittest.TestCase):
     def test_loadable(self):
         c = client.Client("")
-        c.startService()
-        return c.stopService()
+        d = c.startService()
+        d.addCallback(lambda res: c.stopService())
+        return d
 
     def test_permute(self):
         c = client.Client("")
index 7e12917ad3a8ec19193c83a6b75b20a597907b07..0a242c4c75498b821a6291bb021eac3e85d58fc7 100644 (file)
@@ -6,5 +6,7 @@ from allmydata import queen
 class Basic(unittest.TestCase):
     def test_loadable(self):
         q = queen.Queen()
-        q.startService()
-        return q.stopService()
+        d = q.startService()
+        d.addCallback(lambda res: q.stopService())
+        return d
+
index dcc4035e3fb87ac21582098918d488823164972d..31da06ab662a51de8acf3761f15ccfeda20e29b0 100644 (file)
@@ -21,7 +21,8 @@ class StorageTest(unittest.TestCase):
         self.node.setServiceParent(self.svc)
         self.tub = Tub()
         self.tub.setServiceParent(self.svc)
-        return self.svc.startService()
+        self.svc.startService()
+        return self.node.when_tub_ready()
 
     def test_create_bucket(self):
         """
index 3ebc49f13d3fe37cda6b283b57aa9019276a5ea4..0ec50774b8904fa756141f599aac5ede65ce5b26 100644 (file)
@@ -40,10 +40,16 @@ class SystemTest(unittest.TestCase):
         self.numclients = NUMCLIENTS
         if not os.path.isdir("queen"):
             os.mkdir("queen")
-        q = self.queen = self.add_service(queen.Queen(basedir="queen"))
+        self.queen = self.add_service(queen.Queen(basedir="queen"))
+        d = self.queen.when_tub_ready()
+        d.addCallback(self._set_up_nodes_2)
+        return d
+
+    def _set_up_nodes_2(self, res):
+        q = self.queen
         self.queen_pburl = q.urls["roster"]
         self.clients = []
-        for i in range(NUMCLIENTS):
+        for i in range(self.numclients):
             basedir = "client%d" % i
             if not os.path.isdir(basedir):
                 os.mkdir(basedir)