]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/test/no_network.py
Additional tests for MDMF URIs and for zero-length files. refs #393
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
index 554a75f6c7151df47fe7f76040181dab730b77b4..5654c2cedfe7da723f2d4a10ac771d9cd199a71e 100644 (file)
@@ -27,6 +27,8 @@ from allmydata.util import fileutil, idlib, hashutil
 from allmydata.util.hashutil import sha1
 from allmydata.test.common_web import HTTPClientGETFactory
 from allmydata.interfaces import IStorageBroker
+from allmydata.test.common import TEST_RSA_KEY_SIZE
+
 
 class IntentionalError(Exception):
     pass
@@ -117,13 +119,38 @@ def wrap_storage_server(original):
     wrapper.version = original.remote_get_version()
     return wrapper
 
+class NoNetworkServer:
+    def __init__(self, serverid, rref):
+        self.serverid = serverid
+        self.rref = rref
+    def __repr__(self):
+        return "<NoNetworkServer for %s>" % self.get_name()
+    def get_serverid(self):
+        return self.serverid
+    def get_permutation_seed(self):
+        return self.serverid
+    def get_lease_seed(self):
+        return self.serverid
+    def get_name(self):
+        return idlib.shortnodeid_b2a(self.serverid)
+    def get_longname(self):
+        return idlib.nodeid_b2a(self.serverid)
+    def get_nickname(self):
+        return "nickname"
+    def get_rref(self):
+        return self.rref
+    def get_version(self):
+        return self.rref.version
+
 class NoNetworkStorageBroker:
     implements(IStorageBroker)
-    def get_servers_for_index(self, key):
-        return sorted(self.client._servers,
-                      key=lambda x: sha1(key+x[0]).digest())
-    def get_all_servers(self):
-        return frozenset(self.client._servers)
+    def get_servers_for_psi(self, peer_selection_index):
+        def _permuted(server):
+            seed = server.get_permutation_seed()
+            return sha1(peer_selection_index + seed).digest()
+        return sorted(self.get_connected_servers(), key=_permuted)
+    def get_connected_servers(self):
+        return self.client._servers
     def get_nickname_for_serverid(self, serverid):
         return None
 
@@ -181,8 +208,10 @@ class NoNetworkGrid(service.MultiService):
         self.basedir = basedir
         fileutil.make_dirs(basedir)
 
-        self.servers_by_number = {}
-        self.servers_by_id = {}
+        self.servers_by_number = {} # maps to StorageServer instance
+        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
+        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
+                                # StorageServer
         self.clients = []
 
         for i in range(num_servers):
@@ -209,7 +238,7 @@ class NoNetworkGrid(service.MultiService):
                 c = client_config_hooks[i](clientdir)
             if not c:
                 c = NoNetworkClient(clientdir)
-                c.set_default_mutable_keysize(522)
+                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
             c.nodeid = clientid
             c.short_nodeid = b32encode(clientid).lower()[:8]
             c._servers = self.all_servers # can be updated later
@@ -234,11 +263,16 @@ class NoNetworkGrid(service.MultiService):
         ss.setServiceParent(middleman)
         serverid = ss.my_nodeid
         self.servers_by_number[i] = ss
-        self.servers_by_id[serverid] = wrap_storage_server(ss)
+        wrapper = wrap_storage_server(ss)
+        self.wrappers_by_id[serverid] = wrapper
+        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
         self.rebuild_serverlist()
 
+    def get_all_serverids(self):
+        return self.proxies_by_id.keys()
+
     def rebuild_serverlist(self):
-        self.all_servers = frozenset(self.servers_by_id.items())
+        self.all_servers = frozenset(self.proxies_by_id.values())
         for c in self.clients:
             c._servers = self.all_servers
 
@@ -249,23 +283,24 @@ class NoNetworkGrid(service.MultiService):
             if ss.my_nodeid == serverid:
                 del self.servers_by_number[i]
                 break
-        del self.servers_by_id[serverid]
+        del self.wrappers_by_id[serverid]
+        del self.proxies_by_id[serverid]
         self.rebuild_serverlist()
 
     def break_server(self, serverid):
         # mark the given server as broken, so it will throw exceptions when
         # asked to hold a share or serve a share
-        self.servers_by_id[serverid].broken = True
+        self.wrappers_by_id[serverid].broken = True
 
     def hang_server(self, serverid):
         # hang the given server
-        ss = self.servers_by_id[serverid]
+        ss = self.wrappers_by_id[serverid]
         assert ss.hung_until is None
         ss.hung_until = defer.Deferred()
 
     def unhang_server(self, serverid):
         # unhang the given server
-        ss = self.servers_by_id[serverid]
+        ss = self.wrappers_by_id[serverid]
         assert ss.hung_until is not None
         ss.hung_until.callback(None)
         ss.hung_until = None
@@ -371,3 +406,6 @@ class GridTestMixin:
         if return_response:
             d.addCallback(_got_data)
         return factory.deferred
+
+    def PUT(self, urlpath, **kwargs):
+        return self.GET(urlpath, method="PUT", **kwargs)