]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
check_memory: fix race condition for startup of in-process server nodes
authorBrian Warner <warner@allmydata.com>
Thu, 20 Sep 2007 22:33:58 +0000 (15:33 -0700)
committerBrian Warner <warner@allmydata.com>
Thu, 20 Sep 2007 22:33:58 +0000 (15:33 -0700)
src/allmydata/client.py
src/allmydata/control.py
src/allmydata/test/check_memory.py

index d246f90fe697b5391010aeb5b64b87907c98b061..9f0dec4d711bd14c6029b72722c94f62041dd1d6 100644 (file)
@@ -17,9 +17,9 @@ from allmydata.download import Downloader
 from allmydata.control import ControlServer
 from allmydata.introducer import IntroducerClient
 from allmydata.vdrive import VirtualDrive
-from allmydata.util import hashutil, idlib
+from allmydata.util import hashutil, idlib, testutil
 
-class Client(node.Node, Referenceable):
+class Client(node.Node, Referenceable, testutil.PollMixin):
     implements(RIClient)
     PORTNUMFILE = "client.port"
     STOREDIR = 'storage'
@@ -189,3 +189,16 @@ class Client(node.Node, Referenceable):
 
     def get_cancel_secret(self):
         return hashutil.my_cancel_secret_hash(self._secret)
+
+    def debug_wait_for_client_connections(self, num_clients):
+        """Return a Deferred that fires (with None) when we have connections
+        to the given number of peers. Useful for tests that set up a
+        temporary test network and need to know when it is safe to proceed
+        with an upload or download."""
+        def _check():
+            current_clients = list(self.get_all_peerids())
+            return len(current_clients) >= num_clients
+        d = self.poll(_check, 0.5)
+        d.addCallback(lambda res: None)
+        return d
+
index 69f4410e7c2f779f4819ec5e1092f7c32d3a3b45..40d3a3d883b47f1f824d12fca93b269d3ed260b5 100644 (file)
@@ -33,12 +33,7 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin):
     implements(RIControlClient)
 
     def remote_wait_for_client_connections(self, num_clients):
-        def _check():
-            current_clients = list(self.parent.get_all_peerids())
-            return len(current_clients) >= num_clients
-        d = self.poll(_check, 0.5)
-        d.addCallback(lambda res: None)
-        return d
+        return self.parent.debug_wait_for_client_connections(num_clients)
 
     def remote_upload_from_file_to_uri(self, filename):
         uploader = self.parent.getServiceNamed("uploader")
@@ -52,9 +47,6 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin):
         return d
 
     def remote_upload_speed_test(self, size):
-        """Write a tempfile to disk of the given size. Measure how long
-        it takes to upload it to the servers.
-        """
         assert size > 8
         fn = os.path.join(self.parent.basedir, idlib.b2a(os.urandom(8)))
         f = open(fn, "w")
index d75e32b60f75907dad2a18a1cb6e6112f6166115..c62cfd84d6801d1c5935bb35e7c6695e305eea49 100644 (file)
@@ -377,18 +377,21 @@ this file are ignored.
             data = "a" * size
             url = "/vdrive/global"
             d = self.POST(url, t="upload", file=("%d.data" % size, data))
-        elif self.mode in ("receive",):
-            # upload the data from a local peer, so that the
+        elif self.mode in ("receive",
+                           "download", "download-GET", "download-GET-slow"):
+            # mode=receive: upload the data from a local peer, so that the
             # client-under-test receives and stores the shares
+            #
+            # mode=download*: upload the data from a local peer, then have
+            # the client-under-test download it.
+            #
+            # we need to wait until the uploading node has connected to all
+            # peers, since the wait_for_client_connections() above doesn't
+            # pay attention to our self.nodes[] and their connections.
             files[name] = self.create_data(name, size)
             u = self.nodes[0].getServiceNamed("uploader")
-            d = u.upload_filename(files[name])
-        elif self.mode in ("download", "download-GET", "download-GET-slow"):
-            # upload the data from a local peer, then have the
-            # client-under-test download it.
-            files[name] = self.create_data(name, size)
-            u = self.nodes[0].getServiceNamed("uploader")
-            d = u.upload_filename(files[name])
+            d = self.nodes[0].debug_wait_for_client_connections(self.numnodes+1)
+            d.addCallback(lambda res: u.upload_filename(files[name]))
         else:
             raise RuntimeError("unknown mode=%s" % self.mode)
         def _complete(uri):