From: Brian Warner <warner@allmydata.com>
Date: Thu, 20 Sep 2007 22:33:58 +0000 (-0700)
Subject: check_memory: fix race condition for startup of in-process server nodes
X-Git-Tag: allmydata-tahoe-0.6.0~43
X-Git-Url: https://git.rkrishnan.org/vdrive/%22news.html/simplejson/FOOURL?a=commitdiff_plain;h=3774ce59ea87abd50feca23d134637dd56d3d438;p=tahoe-lafs%2Ftahoe-lafs.git

check_memory: fix race condition for startup of in-process server nodes
---

diff --git a/src/allmydata/client.py b/src/allmydata/client.py
index d246f90f..9f0dec4d 100644
--- a/src/allmydata/client.py
+++ b/src/allmydata/client.py
@@ -17,9 +17,9 @@ from allmydata.download import Downloader
 from allmydata.control import ControlServer
 from allmydata.introducer import IntroducerClient
 from allmydata.vdrive import VirtualDrive
-from allmydata.util import hashutil, idlib
+from allmydata.util import hashutil, idlib, testutil
 
-class Client(node.Node, Referenceable):
+class Client(node.Node, Referenceable, testutil.PollMixin):
     implements(RIClient)
     PORTNUMFILE = "client.port"
     STOREDIR = 'storage'
@@ -189,3 +189,16 @@ class Client(node.Node, Referenceable):
 
     def get_cancel_secret(self):
         return hashutil.my_cancel_secret_hash(self._secret)
+
+    def debug_wait_for_client_connections(self, num_clients):
+        """Return a Deferred that fires (with None) when we have connections
+        to the given number of peers. Useful for tests that set up a
+        temporary test network and need to know when it is safe to proceed
+        with an upload or download."""
+        def _check():
+            current_clients = list(self.get_all_peerids())
+            return len(current_clients) >= num_clients
+        d = self.poll(_check, 0.5)
+        d.addCallback(lambda res: None)
+        return d
+
diff --git a/src/allmydata/control.py b/src/allmydata/control.py
index 69f4410e..40d3a3d8 100644
--- a/src/allmydata/control.py
+++ b/src/allmydata/control.py
@@ -33,12 +33,7 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin):
     implements(RIControlClient)
 
     def remote_wait_for_client_connections(self, num_clients):
-        def _check():
-            current_clients = list(self.parent.get_all_peerids())
-            return len(current_clients) >= num_clients
-        d = self.poll(_check, 0.5)
-        d.addCallback(lambda res: None)
-        return d
+        return self.parent.debug_wait_for_client_connections(num_clients)
 
     def remote_upload_from_file_to_uri(self, filename):
         uploader = self.parent.getServiceNamed("uploader")
@@ -52,9 +47,6 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin):
         return d
 
     def remote_upload_speed_test(self, size):
-        """Write a tempfile to disk of the given size. Measure how long
-        it takes to upload it to the servers.
-        """
         assert size > 8
         fn = os.path.join(self.parent.basedir, idlib.b2a(os.urandom(8)))
         f = open(fn, "w")
diff --git a/src/allmydata/test/check_memory.py b/src/allmydata/test/check_memory.py
index d75e32b6..c62cfd84 100644
--- a/src/allmydata/test/check_memory.py
+++ b/src/allmydata/test/check_memory.py
@@ -377,18 +377,21 @@ this file are ignored.
             data = "a" * size
             url = "/vdrive/global"
             d = self.POST(url, t="upload", file=("%d.data" % size, data))
-        elif self.mode in ("receive",):
-            # upload the data from a local peer, so that the
+        elif self.mode in ("receive",
+                           "download", "download-GET", "download-GET-slow"):
+            # mode=receive: upload the data from a local peer, so that the
             # client-under-test receives and stores the shares
+            #
+            # mode=download*: upload the data from a local peer, then have
+            # the client-under-test download it.
+            #
+            # we need to wait until the uploading node has connected to all
+            # peers, since the wait_for_client_connections() above doesn't
+            # pay attention to our self.nodes[] and their connections.
             files[name] = self.create_data(name, size)
             u = self.nodes[0].getServiceNamed("uploader")
-            d = u.upload_filename(files[name])
-        elif self.mode in ("download", "download-GET", "download-GET-slow"):
-            # upload the data from a local peer, then have the
-            # client-under-test download it.
-            files[name] = self.create_data(name, size)
-            u = self.nodes[0].getServiceNamed("uploader")
-            d = u.upload_filename(files[name])
+            d = self.nodes[0].debug_wait_for_client_connections(self.numnodes+1)
+            d.addCallback(lambda res: u.upload_filename(files[name]))
         else:
             raise RuntimeError("unknown mode=%s" % self.mode)
         def _complete(uri):