mutable: add peer-selection to Publish, and some basic unit tests
authorBrian Warner <warner@lothar.com>
Mon, 5 Nov 2007 07:38:07 +0000 (00:38 -0700)
committerBrian Warner <warner@lothar.com>
Mon, 5 Nov 2007 07:38:07 +0000 (00:38 -0700)
src/allmydata/mutable.py
src/allmydata/test/test_mutable.py

index d40b602d9013283f0bd183bc061a44655c3656ff..d4fa1cbfc0c6af047a666f1dbb69bc939f20e39a 100644 (file)
@@ -1,5 +1,5 @@
 
-import os, struct
+import os, struct, itertools
 from zope.interface import implements
 from twisted.internet import defer
 from allmydata.interfaces import IMutableFileNode, IMutableFileURI
@@ -7,6 +7,7 @@ from allmydata.util import hashutil, mathutil
 from allmydata.uri import WriteableSSKFileURI
 from allmydata.Crypto.Cipher import AES
 from allmydata import hashtree, codec
+from allmydata.encode import NotEnoughPeersError
 
 
 HEADER_LENGTH = struct.calcsize(">BQ32s BBQQ LLLLLQQ")
@@ -16,6 +17,8 @@ class NeedMoreDataError(Exception):
         Exception.__init__(self)
         self.needed_bytes = needed_bytes
 
+class UncoordinatedWriteError(Exception):
+    pass
 
 # use client.create_mutable_file() to make one of these
 
@@ -90,10 +93,7 @@ class MutableFileNode:
     def replace(self, newdata):
         return defer.succeed(None)
 
-class Retrieve:
-
-    def __init__(self, filenode):
-        self._node = filenode
+class ShareFormattingMixin:
 
     def _unpack_share(self, data):
         assert len(data) >= HEADER_LENGTH
@@ -140,9 +140,19 @@ class Retrieve:
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 IV, share_data, enc_privkey)
 
+class Retrieve(ShareFormattingMixin):
+    def __init__(self, filenode):
+        self._node = filenode
+
+class DictOfSets(dict):
+    def add(self, key, value):
+        if key in self:
+            self[key].add(value)
+        else:
+            self[key] = set([value])
 
 
-class Publish:
+class Publish(ShareFormattingMixin):
     """I represent a single act of publishing the mutable file to the grid."""
 
     def __init__(self, filenode):
@@ -156,6 +166,9 @@ class Publish:
 
         # 1: generate shares (SDMF: files are small, so we can do it in RAM)
         # 2: perform peer selection, get candidate servers
+        #  2a: send queries to n+epsilon servers, to determine current shares
+        #  2b: based upon responses, create target map
+
         # 3: pre-allocate some shares to some servers, based upon any existing
         #    self._node._sharemap
         # 4: send allocate/testv_and_writev messages
@@ -178,7 +191,7 @@ class Publish:
         d.addCallback(self._generate_shares, old_seqnum+1,
                       privkey, self._encprivkey, pubkey)
 
-        d.addCallback(self._choose_peers_and_map_shares)
+        d.addCallback(self._query_peers, total_shares)
         d.addCallback(self._send_shares)
         d.addCallback(self._wait_for_responses)
         d.addCallback(lambda res: None)
@@ -332,7 +345,7 @@ class Publish:
                            offsets['enc_privkey'],
                            offsets['EOF'])
 
-    def _choose_peers_and_map_shares(self, (seqnum, root_hash, final_shares) ):
+    def _query_peers(self, (seqnum, root_hash, final_shares), total_shares):
         self._new_seqnum = seqnum
         self._new_root_hash = root_hash
         self._new_shares = final_shares
@@ -346,9 +359,89 @@ class Publish:
         # and the directory contents are unrecoverable, at least we can still
         # push out a new copy with brand-new contents.
 
-        new_sharemap = {}
+        current_share_peers = DictOfSets()
+        reachable_peers = {}
+
+        EPSILON = total_shares / 2
+        partial_peerlist = itertools.islice(peerlist, total_shares + EPSILON)
+        peer_storage_servers = {}
+        dl = []
+        for (permutedid, peerid, conn) in partial_peerlist:
+            d = self._do_query(conn, peerid, peer_storage_servers)
+            d.addCallback(self._got_query_results,
+                          peerid, permutedid,
+                          reachable_peers, current_share_peers)
+            dl.append(d)
+        d = defer.DeferredList(dl)
+        d.addCallback(self._got_all_query_results,
+                      total_shares, reachable_peers, seqnum,
+                      current_share_peers, peer_storage_servers)
+        # TODO: add an errback to, probably to ignore that peer
+        return d
 
-        # build the reverse sharehintmap
-        old_hints = {} # nodeid .. ?
-        for shnum, nodeids in self._node._sharemap:
-            pass
+    def _do_query(self, conn, peerid, peer_storage_servers):
+        d = conn.callRemote("get_service", "storageserver")
+        def _got_storageserver(ss):
+            peer_storage_servers[peerid] = ss
+            return ss.callRemote("readv_slots", [(0, 2000)])
+        d.addCallback(_got_storageserver)
+        return d
+
+    def _got_query_results(self, datavs, peerid, permutedid,
+                           reachable_peers, current_share_peers):
+        assert isinstance(datavs, dict)
+        reachable_peers[peerid] = permutedid
+        for shnum, datav in datavs.items():
+            assert len(datav) == 1
+            data = datav[0]
+            r = self._unpack_share(data)
+            share = (shnum, r[0], r[1]) # shnum,seqnum,R
+            current_share_peers[shnum].add( (peerid, r[0], r[1]) )
+
+    def _got_all_query_results(self, res,
+                               total_shares, reachable_peers, new_seqnum,
+                               current_share_peers, peer_storage_servers):
+        # now that we know everything about the shares currently out there,
+        # decide where to place the new shares.
+
+        # if an old share X is on a node, put the new share X there too.
+        # TODO: 1: redistribute shares to achieve one-per-peer, by copying
+        #       shares from existing peers to new (less-crowded) ones. The
+        #       old shares must still be updated.
+        # TODO: 2: move those shares instead of copying them, to reduce future
+        #       update work
+
+        shares_needing_homes = range(total_shares)
+        target_map = DictOfSets() # maps shnum to set((peerid,oldseqnum,oldR))
+        shares_per_peer = DictOfSets()
+        for shnum in range(total_shares):
+            for oldplace in current_share_peers.get(shnum, []):
+                (peerid, seqnum, R) = oldplace
+                if seqnum >= new_seqnum:
+                    raise UncoordinatedWriteError()
+                target_map[shnum].add(oldplace)
+                shares_per_peer.add(peerid, shnum)
+                if shnum in shares_needing_homes:
+                    shares_needing_homes.remove(shnum)
+
+        # now choose homes for the remaining shares. We prefer peers with the
+        # fewest target shares, then peers with the lowest permuted index. If
+        # there are no shares already in place, this will assign them
+        # one-per-peer in the normal permuted order.
+        while shares_needing_homes:
+            if not reachable_peers:
+                raise NotEnoughPeersError("ran out of peers during upload")
+            shnum = shares_needing_homes.pop(0)
+            possible_homes = reachable_peers.keys()
+            possible_homes.sort(lambda a,b:
+                                cmp( (len(shares_per_peer.get(a, [])),
+                                      reachable_peers[a]),
+                                     (len(shares_per_peer.get(b, [])),
+                                      reachable_peers[b]) ))
+            target_peerid = possible_homes[0]
+            target_map.add(shnum, (target_peerid, None, None) )
+            shares_per_peer.add(target_peerid, shnum)
+
+        assert not shares_needing_homes
+
+        return target_map
index c76fcb08e6c4854819dbfce94fef586e8d57d020..8a94c4b560849b1448ed978438068dbe6a5d4d11 100644 (file)
@@ -2,10 +2,15 @@
 import itertools, struct
 from twisted.trial import unittest
 from twisted.internet import defer
-
+from twisted.python import failure
 from allmydata import mutable, uri, dirnode2
 from allmydata.dirnode2 import split_netstring
-from allmydata.util.hashutil import netstring
+from allmydata.util import idlib
+from allmydata.util.hashutil import netstring, tagged_hash
+from allmydata.encode import NotEnoughPeersError
+
+import sha
+from allmydata.Crypto.Util.number import bytes_to_long
 
 class Netstring(unittest.TestCase):
     def test_split(self):
@@ -62,12 +67,21 @@ class FakeFilenode(mutable.MutableFileNode):
     def get_readonly(self):
         return "fake readonly"
 
+class FakePublish(mutable.Publish):
+    def _do_query(self, conn, peerid, peer_storage_servers):
+        assert conn[0] == peerid
+        shares = self._peers[peerid]
+        return defer.succeed(shares)
+
+
 class FakeNewDirectoryNode(dirnode2.NewDirectoryNode):
     filenode_class = FakeFilenode
 
 class MyClient:
-    def __init__(self):
-        pass
+    def __init__(self, num_peers=10):
+        self._num_peers = num_peers
+        self._peerids = [tagged_hash("peerid", "%d" % i)
+                         for i in range(self._num_peers)]
 
     def create_empty_dirnode(self):
         n = FakeNewDirectoryNode(self)
@@ -86,6 +100,18 @@ class MyClient:
     def create_mutable_file_from_uri(self, u):
         return FakeFilenode(self).init_from_uri(u)
 
+    def get_permuted_peers(self, key, include_myself=True):
+        """
+        @return: list of (permuted-peerid, peerid, connection,)
+        """
+        peers_and_connections = [(pid, (pid,)) for pid in self._peerids]
+        results = []
+        for peerid, connection in peers_and_connections:
+            assert isinstance(peerid, str)
+            permuted = bytes_to_long(sha.new(key + peerid).digest())
+            results.append((permuted, peerid, connection))
+        results.sort()
+        return results
 
 class Filenode(unittest.TestCase):
     def setUp(self):
@@ -204,6 +230,98 @@ class Publish(unittest.TestCase):
         d.addCallback(_done)
         return d
 
+    def setup_for_sharemap(self, num_peers):
+        c = MyClient(num_peers)
+        fn = FakeFilenode(c)
+        # .create usually returns a Deferred, but we happen to know it's
+        # synchronous
+        CONTENTS = "some initial contents"
+        fn.create(CONTENTS)
+        p = FakePublish(fn)
+        #r = mutable.Retrieve(fn)
+        p._peers = {}
+        for peerid in c._peerids:
+            p._peers[peerid] = {}
+        return c, p
+
+    def shouldFail(self, expected_failure, which, call, *args, **kwargs):
+        substring = kwargs.pop("substring", None)
+        d = defer.maybeDeferred(call, *args, **kwargs)
+        def _done(res):
+            if isinstance(res, failure.Failure):
+                res.trap(expected_failure)
+                if substring:
+                    self.failUnless(substring in str(res),
+                                    "substring '%s' not in '%s'"
+                                    % (substring, str(res)))
+            else:
+                self.fail("%s was supposed to raise %s, not get '%s'" %
+                          (which, expected_failure, res))
+        d.addBoth(_done)
+        return d
+
+    def test_sharemap_20newpeers(self):
+        c, p = self.setup_for_sharemap(20)
+
+        new_seqnum = 3
+        new_root_hash = "Rnew"
+        new_shares = None
+        total_shares = 10
+        d = p._query_peers( (new_seqnum, new_root_hash, new_seqnum),
+                            total_shares)
+        def _done(target_map):
+            shares_per_peer = {}
+            for shnum in target_map:
+                for (peerid, old_seqnum, old_R) in target_map[shnum]:
+                    #print "shnum[%d]: send to %s [oldseqnum=%s]" % \
+                    #      (shnum, idlib.b2a(peerid), old_seqnum)
+                    if peerid not in shares_per_peer:
+                        shares_per_peer[peerid] = 1
+                    else:
+                        shares_per_peer[peerid] += 1
+            # verify that we're sending only one share per peer
+            for peerid, count in shares_per_peer.items():
+                self.failUnlessEqual(count, 1)
+        d.addCallback(_done)
+        return d
+
+    def test_sharemap_3newpeers(self):
+        c, p = self.setup_for_sharemap(3)
+
+        new_seqnum = 3
+        new_root_hash = "Rnew"
+        new_shares = None
+        total_shares = 10
+        d = p._query_peers( (new_seqnum, new_root_hash, new_seqnum),
+                            total_shares)
+        def _done(target_map):
+            shares_per_peer = {}
+            for shnum in target_map:
+                for (peerid, old_seqnum, old_R) in target_map[shnum]:
+                    if peerid not in shares_per_peer:
+                        shares_per_peer[peerid] = 1
+                    else:
+                        shares_per_peer[peerid] += 1
+            # verify that we're sending 3 or 4 shares per peer
+            for peerid, count in shares_per_peer.items():
+                self.failUnless(count in (3,4), count)
+        d.addCallback(_done)
+        return d
+
+    def test_sharemap_nopeers(self):
+        c, p = self.setup_for_sharemap(0)
+
+        new_seqnum = 3
+        new_root_hash = "Rnew"
+        new_shares = None
+        total_shares = 10
+        d = self.shouldFail(NotEnoughPeersError, "test_sharemap_nopeers",
+                            p._query_peers,
+                            (new_seqnum, new_root_hash, new_seqnum),
+                            total_shares)
+        return d
+
+
 class FakePubKey:
     def serialize(self):
         return "PUBKEY"