From 711c09bc5d17f10adb87008480fcba036fc13d1a Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Sun, 21 Jun 2009 16:51:19 -0700
Subject: [PATCH] clean up storage_broker interface: should fix #732

---
 src/allmydata/client.py              |  5 -----
 src/allmydata/immutable/download.py  |  6 ++++--
 src/allmydata/immutable/filenode.py  | 10 +++++++---
 src/allmydata/immutable/offloaded.py |  2 +-
 src/allmydata/immutable/repairer.py  |  3 ++-
 src/allmydata/immutable/upload.py    |  2 +-
 src/allmydata/interfaces.py          | 17 +++++++++++++++++
 src/allmydata/mutable/publish.py     |  2 +-
 src/allmydata/mutable/servermap.py   |  2 +-
 src/allmydata/storage_client.py      | 20 +++++++++++++-------
 src/allmydata/test/no_network.py     | 11 +++++++----
 src/allmydata/test/test_client.py    |  2 +-
 src/allmydata/test/test_encode.py    | 12 ++++++------
 src/allmydata/test/test_mutable.py   |  2 +-
 src/allmydata/test/test_system.py    |  4 ++--
 src/allmydata/web/check_results.py   |  2 +-
 16 files changed, 65 insertions(+), 37 deletions(-)

diff --git a/src/allmydata/client.py b/src/allmydata/client.py
index 0a08cddd..037489ad 100644
--- a/src/allmydata/client.py
+++ b/src/allmydata/client.py
@@ -326,11 +326,6 @@ class Client(node.Node, pollmixin.PollMixin):
     def _lost_key_generator(self):
         self._key_generator = None
 
-    def get_servers(self, service_name):
-        """ Return frozenset of (peerid, versioned-rref) """
-        assert isinstance(service_name, str)
-        return self.introducer_client.get_peers(service_name)
-
     def init_web(self, webport):
         self.log("init_web(webport=%s)", args=(webport,))
 
diff --git a/src/allmydata/immutable/download.py b/src/allmydata/immutable/download.py
index 1882a75e..acc03add 100644
--- a/src/allmydata/immutable/download.py
+++ b/src/allmydata/immutable/download.py
@@ -9,7 +9,8 @@ from allmydata.util import base32, deferredutil, hashutil, log, mathutil, idlib
 from allmydata.util.assertutil import _assert, precondition
 from allmydata import codec, hashtree, uri
 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, IVerifierURI, \
-     IDownloadStatus, IDownloadResults, IValidatedThingProxy, NotEnoughSharesError, \
+     IDownloadStatus, IDownloadResults, IValidatedThingProxy, \
+     IStorageBroker, NotEnoughSharesError, \
      UnableToFetchCriticalDownloadDataError
 from allmydata.immutable import layout
 from allmydata.monitor import Monitor
@@ -626,6 +627,7 @@ class CiphertextDownloader(log.PrefixingLogMixin):
 
     def __init__(self, storage_broker, v, target, monitor):
 
+        precondition(IStorageBroker.providedBy(storage_broker), storage_broker)
         precondition(IVerifierURI.providedBy(v), v)
         precondition(IDownloadTarget.providedBy(target), target)
 
@@ -745,7 +747,7 @@ class CiphertextDownloader(log.PrefixingLogMixin):
     def _get_all_shareholders(self):
         dl = []
         sb = self._storage_broker
-        for (peerid,ss) in sb.get_servers(self._storage_index):
+        for (peerid,ss) in sb.get_servers_for_index(self._storage_index):
             self.log(format="sending DYHB to [%(peerid)s]",
                      peerid=idlib.shortnodeid_b2a(peerid),
                      level=log.NOISY, umid="rT03hg")
diff --git a/src/allmydata/immutable/filenode.py b/src/allmydata/immutable/filenode.py
index bace4357..7ff2aaca 100644
--- a/src/allmydata/immutable/filenode.py
+++ b/src/allmydata/immutable/filenode.py
@@ -201,7 +201,8 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
 
     def check_and_repair(self, monitor, verify=False, add_lease=False):
         verifycap = self.get_verify_cap()
-        servers = self._client.get_servers("storage")
+        sb = self._client.get_storage_broker()
+        servers = sb.get_all_servers()
 
         c = Checker(client=self._client, verifycap=verifycap, servers=servers,
                     verify=verify, add_lease=add_lease, monitor=monitor)
@@ -253,8 +254,11 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
         return d
 
     def check(self, monitor, verify=False, add_lease=False):
-        v = Checker(client=self._client, verifycap=self.get_verify_cap(),
-                    servers=self._client.get_servers("storage"),
+        verifycap = self.get_verify_cap()
+        sb = self._client.get_storage_broker()
+        servers = sb.get_all_servers()
+
+        v = Checker(client=self._client, verifycap=verifycap, servers=servers,
                     verify=verify, add_lease=add_lease, monitor=monitor)
         return v.start()
 
diff --git a/src/allmydata/immutable/offloaded.py b/src/allmydata/immutable/offloaded.py
index a71bf132..88c3099f 100644
--- a/src/allmydata/immutable/offloaded.py
+++ b/src/allmydata/immutable/offloaded.py
@@ -619,7 +619,7 @@ class Helper(Referenceable, service.MultiService):
         lp2 = self.log("doing a quick check+UEBfetch",
                        parent=lp, level=log.NOISY)
         sb = self.parent.get_storage_broker()
-        c = CHKCheckerAndUEBFetcher(sb.get_servers, storage_index, lp2)
+        c = CHKCheckerAndUEBFetcher(sb.get_servers_for_index, storage_index, lp2)
         d = c.check()
         def _checked(res):
             if res:
diff --git a/src/allmydata/immutable/repairer.py b/src/allmydata/immutable/repairer.py
index 84118a48..a02e8adb 100644
--- a/src/allmydata/immutable/repairer.py
+++ b/src/allmydata/immutable/repairer.py
@@ -50,7 +50,8 @@ class Repairer(log.PrefixingLogMixin):
     def start(self):
         self.log("starting repair")
         duc = DownUpConnector()
-        dl = download.CiphertextDownloader(self._client, self._verifycap, target=duc, monitor=self._monitor)
+        sb = self._client.get_storage_broker()
+        dl = download.CiphertextDownloader(sb, self._verifycap, target=duc, monitor=self._monitor)
         ul = upload.CHKUploader(self._client)
 
         d = defer.Deferred()
diff --git a/src/allmydata/immutable/upload.py b/src/allmydata/immutable/upload.py
index 9a0bce97..0dc541db 100644
--- a/src/allmydata/immutable/upload.py
+++ b/src/allmydata/immutable/upload.py
@@ -167,7 +167,7 @@ class Tahoe2PeerSelector:
         self.preexisting_shares = {} # sharenum -> peerid holding the share
 
         sb = client.get_storage_broker()
-        peers = list(sb.get_servers(storage_index))
+        peers = sb.get_servers_for_index(storage_index)
         if not peers:
             raise NoServersError("client gave us zero peers")
 
diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index 13b96e7e..24d67a50 100644
--- a/src/allmydata/interfaces.py
+++ b/src/allmydata/interfaces.py
@@ -349,6 +349,23 @@ class IStorageBucketReader(Interface):
         @return: URIExtensionData
         """
 
+class IStorageBroker(Interface):
+    def get_servers_for_index(peer_selection_index):
+        """
+        @return: list of (peerid, versioned-rref) tuples
+        """
+    def get_all_servers():
+        """
+        @return: frozenset of (peerid, versioned-rref) tuples
+        """
+    def get_all_serverids():
+        """
+        @return: iterator of serverid strings
+        """
+    def get_nickname_for_serverid(serverid):
+        """
+        @return: unicode nickname, or None
+        """
 
 
 # hm, we need a solution for forward references in schemas
diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py
index 60de09ee..942d87c2 100644
--- a/src/allmydata/mutable/publish.py
+++ b/src/allmydata/mutable/publish.py
@@ -177,7 +177,7 @@ class Publish:
         self._encprivkey = self._node.get_encprivkey()
 
         sb = self._node._client.get_storage_broker()
-        full_peerlist = sb.get_servers(self._storage_index)
+        full_peerlist = sb.get_servers_for_index(self._storage_index)
         self.full_peerlist = full_peerlist # for use later, immutable
         self.bad_peers = set() # peerids who have errbacked/refused requests
 
diff --git a/src/allmydata/mutable/servermap.py b/src/allmydata/mutable/servermap.py
index 9c598580..0efa37ba 100644
--- a/src/allmydata/mutable/servermap.py
+++ b/src/allmydata/mutable/servermap.py
@@ -422,7 +422,7 @@ class ServermapUpdater:
         self._queries_completed = 0
 
         sb = self._node._client.get_storage_broker()
-        full_peerlist = list(sb.get_servers(self._node._storage_index))
+        full_peerlist = sb.get_servers_for_index(self._node._storage_index)
         self.full_peerlist = full_peerlist # for use later, immutable
         self.extra_peers = full_peerlist[:] # peers are removed as we use them
         self._good_peers = set() # peers who had some shares
diff --git a/src/allmydata/storage_client.py b/src/allmydata/storage_client.py
index e8050379..eb4a3733 100644
--- a/src/allmydata/storage_client.py
+++ b/src/allmydata/storage_client.py
@@ -19,8 +19,11 @@ the foolscap-based server implemented in src/allmydata/storage/*.py .
 #  implement tahoe.cfg scanner, create static NativeStorageClients
 
 import sha
+from zope.interface import implements
+from allmydata.interfaces import IStorageBroker
 
 class StorageFarmBroker:
+    implements(IStorageBroker)
     """I live on the client, and know about storage servers. For each server
     that is participating in a grid, I either maintain a connection to it or
     remember enough information to establish a connection to it on demand.
@@ -38,20 +41,23 @@ class StorageFarmBroker:
         self.introducer_client = ic = introducer_client
         ic.subscribe_to("storage")
 
-    def get_servers(self, peer_selection_index):
-        # first cut: return an iterator of (peerid, versioned-rref) tuples
+    def get_servers_for_index(self, peer_selection_index):
+        # first cut: return a list of (peerid, versioned-rref) tuples
         assert self.permute_peers == True
+        servers = self.get_all_servers()
+        key = peer_selection_index
+        return sorted(servers, key=lambda x: sha.new(key+x[0]).digest())
+
+    def get_all_servers(self):
+        # return a frozenset of (peerid, versioned-rref) tuples
         servers = {}
         for serverid,server in self.servers.items():
             servers[serverid] = server
         if self.introducer_client:
             ic = self.introducer_client
-            for serverid,server in ic.get_permuted_peers("storage",
-                                                         peer_selection_index):
+            for serverid,server in ic.get_peers("storage"):
                 servers[serverid] = server
-        servers = servers.items()
-        key = peer_selection_index
-        return sorted(servers, key=lambda x: sha.new(key+x[0]).digest())
+        return frozenset(servers.items())
 
     def get_all_serverids(self):
         for serverid in self.servers:
diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py
index d5d3b48c..8af6d1b8 100644
--- a/src/allmydata/test/no_network.py
+++ b/src/allmydata/test/no_network.py
@@ -15,6 +15,7 @@
 
 import os.path
 import sha
+from zope.interface import implements
 from twisted.application import service
 from twisted.internet import reactor
 from twisted.python.failure import Failure
@@ -26,6 +27,7 @@ from allmydata.storage.server import StorageServer, storage_index_to_dir
 from allmydata.util import fileutil, idlib, hashutil
 from allmydata.introducer.client import RemoteServiceConnector
 from allmydata.test.common_web import HTTPClientGETFactory
+from allmydata.interfaces import IStorageBroker
 
 class IntentionalError(Exception):
     pass
@@ -105,9 +107,12 @@ def wrap(original, service_name):
     return wrapper
 
 class NoNetworkStorageBroker:
-    def get_servers(self, key):
+    implements(IStorageBroker)
+    def get_servers_for_index(self, key):
         return sorted(self.client._servers,
                       key=lambda x: sha.new(key+x[0]).digest())
+    def get_all_servers(self):
+        return frozenset(self.client._servers)
     def get_nickname_for_serverid(self, serverid):
         return None
 
@@ -138,9 +143,7 @@ class NoNetworkClient(Client):
         self.storage_broker.client = self
     def init_stub_client(self):
         pass
-
-    def get_servers(self, service_name):
-        return self._servers
+    #._servers will be set by the NoNetworkGrid which creates us
 
 class SimpleStats:
     def __init__(self):
diff --git a/src/allmydata/test/test_client.py b/src/allmydata/test/test_client.py
index 06077a5c..63f4962f 100644
--- a/src/allmydata/test/test_client.py
+++ b/src/allmydata/test/test_client.py
@@ -143,7 +143,7 @@ class Basic(unittest.TestCase):
 
     def _permute(self, sb, key):
         return [ peerid
-                 for (peerid,rref) in sb.get_servers(key) ]
+                 for (peerid,rref) in sb.get_servers_for_index(key) ]
 
     def test_permute(self):
         sb = StorageFarmBroker()
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index 6e8ba069..c6a7bda1 100644
--- a/src/allmydata/test/test_encode.py
+++ b/src/allmydata/test/test_encode.py
@@ -8,7 +8,8 @@ from allmydata import hashtree, uri
 from allmydata.immutable import encode, upload, download
 from allmydata.util import hashutil
 from allmydata.util.assertutil import _assert
-from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader, NotEnoughSharesError
+from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader, \
+     NotEnoughSharesError, IStorageBroker
 from allmydata.monitor import Monitor
 import common_util as testutil
 
@@ -18,9 +19,8 @@ class LostPeerError(Exception):
 def flip_bit(good): # flips the last bit
     return good[:-1] + chr(ord(good[-1]) ^ 0x01)
 
-class FakeClient:
-    def log(self, *args, **kwargs):
-        pass
+class FakeStorageBroker:
+    implements(IStorageBroker)
 
 class FakeBucketReaderWriterProxy:
     implements(IStorageBucketWriter, IStorageBucketReader)
@@ -494,11 +494,11 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
                            total_shares=verifycap.total_shares,
                            size=verifycap.size)
 
-        client = FakeClient()
+        sb = FakeStorageBroker()
         if not target:
             target = download.Data()
         target = download.DecryptingTarget(target, u.key)
-        fd = download.CiphertextDownloader(client, u.get_verify_cap(), target, monitor=Monitor())
+        fd = download.CiphertextDownloader(sb, u.get_verify_cap(), target, monitor=Monitor())
 
         # we manually cycle the CiphertextDownloader through a number of steps that
         # would normally be sequenced by a Deferred chain in
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 0862bce6..23bc6761 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -1914,7 +1914,7 @@ class Problems(unittest.TestCase, testutil.ShouldFailMixin):
         d.addCallback(n._generated)
         def _break_peer0(res):
             si = n.get_storage_index()
-            peerlist = list(self.client.storage_broker.get_servers(si))
+            peerlist = self.client.storage_broker.get_servers_for_index(si)
             peerid0, connection0 = peerlist[0]
             peerid1, connection1 = peerlist[1]
             connection0.broken = True
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index 1e9abb23..861a2007 100644
--- a/src/allmydata/test/test_system.py
+++ b/src/allmydata/test/test_system.py
@@ -76,7 +76,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
                 all_peerids = list(c.get_storage_broker().get_all_serverids())
                 self.failUnlessEqual(len(all_peerids), self.numclients+1)
                 sb = c.storage_broker
-                permuted_peers = list(sb.get_servers("a"))
+                permuted_peers = list(sb.get_servers_for_index("a"))
                 self.failUnlessEqual(len(permuted_peers), self.numclients+1)
 
         d.addCallback(_check)
@@ -111,7 +111,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
                 all_peerids = list(c.get_storage_broker().get_all_serverids())
                 self.failUnlessEqual(len(all_peerids), self.numclients)
                 sb = c.storage_broker
-                permuted_peers = list(sb.get_servers("a"))
+                permuted_peers = list(sb.get_servers_for_index("a"))
                 self.failUnlessEqual(len(permuted_peers), self.numclients)
         d.addCallback(_check_connections)
 
diff --git a/src/allmydata/web/check_results.py b/src/allmydata/web/check_results.py
index 94fb7f61..0f12f66e 100644
--- a/src/allmydata/web/check_results.py
+++ b/src/allmydata/web/check_results.py
@@ -141,7 +141,7 @@ class ResultsBase:
         sb = c.get_storage_broker()
         permuted_peer_ids = [peerid
                              for (peerid, rref)
-                             in sb.get_servers(cr.get_storage_index())]
+                             in sb.get_servers_for_index(cr.get_storage_index())]
 
         num_shares_left = sum([len(shares) for shares in servers.values()])
         servermap = []
-- 
2.45.2