From b9c3359636cf949367feefa58592e1a1c4f843b1 Mon Sep 17 00:00:00 2001
From: Daira Hopwood <daira@jacaranda.org>
Date: Thu, 3 Sep 2015 18:29:45 +0100
Subject: [PATCH] Add a restart_client method to GridTestMixin.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
---
 src/allmydata/test/no_network.py | 65 +++++++++++++++++++++++---------
 1 file changed, 47 insertions(+), 18 deletions(-)

diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py
index 7b7237bb..b84a7a2e 100644
--- a/src/allmydata/test/no_network.py
+++ b/src/allmydata/test/no_network.py
@@ -20,6 +20,9 @@ from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
+
+from allmydata.util.assertutil import _assert
+
 from allmydata import uri as tahoe_uri
 from allmydata.client import Client
 from allmydata.storage.server import StorageServer, storage_index_to_dir
@@ -235,6 +238,7 @@ class NoNetworkGrid(service.MultiService):
         self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                 # StorageServer
         self.clients = []
+        self.client_config_hooks = client_config_hooks
 
         for i in range(num_servers):
             ss = self.make_server(i)
@@ -242,30 +246,42 @@ class NoNetworkGrid(service.MultiService):
         self.rebuild_serverlist()
 
         for i in range(num_clients):
-            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
-            clientdir = os.path.join(basedir, "clients",
-                                     idlib.shortnodeid_b2a(clientid))
-            fileutil.make_dirs(clientdir)
-            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
+            c = self.make_client(i)
+            self.clients.append(c)
+
+    def make_client(self, i, write_config=True):
+        clientid = hashutil.tagged_hash("clientid", str(i))[:20]
+        clientdir = os.path.join(self.basedir, "clients",
+                                 idlib.shortnodeid_b2a(clientid))
+        fileutil.make_dirs(clientdir)
+
+        tahoe_cfg_path = os.path.join(clientdir, "tahoe.cfg")
+        if write_config:
+            f = open(tahoe_cfg_path, "w")
             f.write("[node]\n")
             f.write("nickname = client-%d\n" % i)
             f.write("web.port = tcp:0:interface=127.0.0.1\n")
             f.write("[storage]\n")
             f.write("enabled = false\n")
             f.close()
-            c = None
-            if i in client_config_hooks:
-                # this hook can either modify tahoe.cfg, or return an
-                # entirely new Client instance
-                c = client_config_hooks[i](clientdir)
-            if not c:
-                c = NoNetworkClient(clientdir)
-                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
-            c.nodeid = clientid
-            c.short_nodeid = b32encode(clientid).lower()[:8]
-            c._servers = self.all_servers # can be updated later
-            c.setServiceParent(self)
-            self.clients.append(c)
+        else:
+            _assert(os.path.exists(tahoe_cfg_path), tahoe_cfg_path=tahoe_cfg_path)
+
+        c = None
+        if i in self.client_config_hooks:
+            # this hook can either modify tahoe.cfg, or return an
+            # entirely new Client instance
+            c = self.client_config_hooks[i](clientdir)
+
+        if not c:
+            c = NoNetworkClient(clientdir)
+            c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
+
+        c.nodeid = clientid
+        c.short_nodeid = b32encode(clientid).lower()[:8]
+        c._servers = self.all_servers # can be updated later
+        c.setServiceParent(self)
+        return c
 
     def make_server(self, i, readonly=False):
         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
@@ -353,6 +369,9 @@ class GridTestMixin:
                                num_servers=num_servers,
                                client_config_hooks=client_config_hooks)
         self.g.setServiceParent(self.s)
+        self._record_webports_and_baseurls()
+
+    def _record_webports_and_baseurls(self):
         self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                 for c in self.g.clients]
         self.client_baseurls = [c.getServiceNamed("webish").getURL()
@@ -364,6 +383,16 @@ class GridTestMixin:
     def get_client(self, i=0):
         return self.g.clients[i]
 
+    def restart_client(self, i=0):
+        client = self.g.clients[i]
+        d = client.stopService()
+        def _make_client(ign):
+            c = self.g.make_client(i, write_config=False)
+            self.g.clients[i] = c
+            self._record_webports_and_baseurls()
+        d.addCallback(_make_client)
+        return d
+
     def get_serverdir(self, i):
         return self.g.servers_by_number[i].storedir
 
-- 
2.45.2