From: Brian Warner Date: Thu, 20 Sep 2007 22:33:58 +0000 (-0700) Subject: check_memory: fix race condition for startup of in-process server nodes X-Git-Tag: allmydata-tahoe-0.6.0~43 X-Git-Url: https://git.rkrishnan.org/%5B/%5D%20/uri//%22%22?a=commitdiff_plain;h=3774ce59ea87abd50feca23d134637dd56d3d438;p=tahoe-lafs%2Ftahoe-lafs.git check_memory: fix race condition for startup of in-process server nodes --- diff --git a/src/allmydata/client.py b/src/allmydata/client.py index d246f90f..9f0dec4d 100644 --- a/src/allmydata/client.py +++ b/src/allmydata/client.py @@ -17,9 +17,9 @@ from allmydata.download import Downloader from allmydata.control import ControlServer from allmydata.introducer import IntroducerClient from allmydata.vdrive import VirtualDrive -from allmydata.util import hashutil, idlib +from allmydata.util import hashutil, idlib, testutil -class Client(node.Node, Referenceable): +class Client(node.Node, Referenceable, testutil.PollMixin): implements(RIClient) PORTNUMFILE = "client.port" STOREDIR = 'storage' @@ -189,3 +189,16 @@ class Client(node.Node, Referenceable): def get_cancel_secret(self): return hashutil.my_cancel_secret_hash(self._secret) + + def debug_wait_for_client_connections(self, num_clients): + """Return a Deferred that fires (with None) when we have connections + to the given number of peers. Useful for tests that set up a + temporary test network and need to know when it is safe to proceed + with an upload or download.""" + def _check(): + current_clients = list(self.get_all_peerids()) + return len(current_clients) >= num_clients + d = self.poll(_check, 0.5) + d.addCallback(lambda res: None) + return d + diff --git a/src/allmydata/control.py b/src/allmydata/control.py index 69f4410e..40d3a3d8 100644 --- a/src/allmydata/control.py +++ b/src/allmydata/control.py @@ -33,12 +33,7 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin): implements(RIControlClient) def remote_wait_for_client_connections(self, num_clients): - def _check(): - current_clients = list(self.parent.get_all_peerids()) - return len(current_clients) >= num_clients - d = self.poll(_check, 0.5) - d.addCallback(lambda res: None) - return d + return self.parent.debug_wait_for_client_connections(num_clients) def remote_upload_from_file_to_uri(self, filename): uploader = self.parent.getServiceNamed("uploader") @@ -52,9 +47,6 @@ class ControlServer(Referenceable, service.Service, testutil.PollMixin): return d def remote_upload_speed_test(self, size): - """Write a tempfile to disk of the given size. Measure how long - it takes to upload it to the servers. - """ assert size > 8 fn = os.path.join(self.parent.basedir, idlib.b2a(os.urandom(8))) f = open(fn, "w") diff --git a/src/allmydata/test/check_memory.py b/src/allmydata/test/check_memory.py index d75e32b6..c62cfd84 100644 --- a/src/allmydata/test/check_memory.py +++ b/src/allmydata/test/check_memory.py @@ -377,18 +377,21 @@ this file are ignored. data = "a" * size url = "/vdrive/global" d = self.POST(url, t="upload", file=("%d.data" % size, data)) - elif self.mode in ("receive",): - # upload the data from a local peer, so that the + elif self.mode in ("receive", + "download", "download-GET", "download-GET-slow"): + # mode=receive: upload the data from a local peer, so that the # client-under-test receives and stores the shares + # + # mode=download*: upload the data from a local peer, then have + # the client-under-test download it. + # + # we need to wait until the uploading node has connected to all + # peers, since the wait_for_client_connections() above doesn't + # pay attention to our self.nodes[] and their connections. files[name] = self.create_data(name, size) u = self.nodes[0].getServiceNamed("uploader") - d = u.upload_filename(files[name]) - elif self.mode in ("download", "download-GET", "download-GET-slow"): - # upload the data from a local peer, then have the - # client-under-test download it. - files[name] = self.create_data(name, size) - u = self.nodes[0].getServiceNamed("uploader") - d = u.upload_filename(files[name]) + d = self.nodes[0].debug_wait_for_client_connections(self.numnodes+1) + d.addCallback(lambda res: u.upload_filename(files[name])) else: raise RuntimeError("unknown mode=%s" % self.mode) def _complete(uri):