From 09fd2dfb3aa0793eac454ad7c51990d239910255 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Wed, 7 Nov 2007 21:01:39 -0700
Subject: [PATCH] mutable: rearrange order of Publish to allow replace() to
 work. Doesn't work yet. Also test_mutable is disabled for a while.

---
 src/allmydata/mutable.py           | 336 ++++++++++++++++++-----------
 src/allmydata/test/test_mutable.py |   1 +
 src/allmydata/test/test_system.py  |  38 +++-
 3 files changed, 240 insertions(+), 135 deletions(-)

diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py
index d97e76fd..a3d93ab5 100644
--- a/src/allmydata/mutable.py
+++ b/src/allmydata/mutable.py
@@ -337,6 +337,7 @@ class Retrieve:
                     raise CorruptShareError(peerid,
                                             "pubkey doesn't match fingerprint")
                 self._pubkey = self._deserialize_pubkey(pubkey_s)
+                self._node._populate_pubkey(self._pubkey)
 
             verinfo = (seqnum, root_hash, IV)
             if verinfo not in self._valid_versions:
@@ -352,8 +353,10 @@ class Retrieve:
                 # and make a note of the other parameters we've just learned
                 if self._required_shares is None:
                     self._required_shares = k
+                    self._node._populate_required_shares(k)
                 if self._total_shares is None:
                     self._total_shares = N
+                    self._node._populate_total_shares(N)
                 if self._segsize is None:
                     self._segsize = segsize
                 if self._datalength is None:
@@ -494,7 +497,7 @@ class Retrieve:
         # all for the same seqnum+root_hash version, so it's now down to
         # doing FEC and decrypt.
         d = defer.maybeDeferred(self._decode, shares)
-        d.addCallback(self._decrypt, IV)
+        d.addCallback(self._decrypt, IV, seqnum, root_hash)
         return d
 
     def _validate_share_and_extract_data(self, root_hash, shnum, data):
@@ -559,10 +562,13 @@ class Retrieve:
         d.addErrback(_err)
         return d
 
-    def _decrypt(self, crypttext, IV):
+    def _decrypt(self, crypttext, IV, seqnum, root_hash):
         key = hashutil.ssk_readkey_data_hash(IV, self._readkey)
         decryptor = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16)
         plaintext = decryptor.decrypt(crypttext)
+        # it worked, so record the seqnum and root_hash for next time
+        self._node._populate_seqnum(seqnum)
+        self._node._populate_root_hash(root_hash)
         return plaintext
 
     def _done(self, contents):
@@ -608,29 +614,159 @@ class Publish:
 
         old_roothash = self._node._current_roothash
         old_seqnum = self._node._current_seqnum
+        assert old_seqnum is not None, "must read before replace"
+        self._new_seqnum = old_seqnum + 1
 
+        # read-before-replace also guarantees these fields are available
         readkey = self._node.get_readkey()
         required_shares = self._node.get_required_shares()
         total_shares = self._node.get_total_shares()
-        privkey = self._node.get_privkey()
-        encprivkey = self._node.get_encprivkey()
-        pubkey = self._node.get_pubkey()
+        self._pubkey = self._node.get_pubkey()
+
+        # these two may not be, we might have to get them from the first peer
+        self._privkey = self._node.get_privkey()
+        self._encprivkey = self._node.get_encprivkey()
 
         IV = os.urandom(16)
 
-        d = defer.succeed(newdata)
-        d.addCallback(self._encrypt_and_encode, readkey, IV,
+        d = defer.succeed(total_shares)
+        d.addCallback(self._query_peers)
+
+        d.addCallback(self._encrypt_and_encode, newdata, readkey, IV,
                       required_shares, total_shares)
-        d.addCallback(self._generate_shares, old_seqnum+1,
-                      privkey, encprivkey, pubkey)
+        d.addCallback(self._generate_shares, self._new_seqnum, IV)
 
-        d.addCallback(self._query_peers, total_shares)
         d.addCallback(self._send_shares, IV)
         d.addCallback(self._maybe_recover)
         d.addCallback(lambda res: None)
         return d
 
-    def _encrypt_and_encode(self, newdata, readkey, IV,
+    def _query_peers(self, total_shares):
+        self.log("_query_peers")
+
+        storage_index = self._node.get_storage_index()
+        peerlist = self._node._client.get_permuted_peers(storage_index,
+                                                         include_myself=False)
+        # we don't include ourselves in the N peers, but we *do* push an
+        # extra copy of share[0] to ourselves so we're more likely to have
+        # the signing key around later. This way, even if all the servers die
+        # and the directory contents are unrecoverable, at least we can still
+        # push out a new copy with brand-new contents.
+        # TODO: actually push this copy
+
+        current_share_peers = DictOfSets()
+        reachable_peers = {}
+
+        EPSILON = total_shares / 2
+        partial_peerlist = islice(peerlist, total_shares + EPSILON)
+        peer_storage_servers = {}
+        dl = []
+        for (permutedid, peerid, conn) in partial_peerlist:
+            d = self._do_query(conn, peerid, peer_storage_servers,
+                               storage_index)
+            d.addCallback(self._got_query_results,
+                          peerid, permutedid,
+                          reachable_peers, current_share_peers)
+            dl.append(d)
+        d = defer.DeferredList(dl)
+        d.addCallback(self._got_all_query_results,
+                      total_shares, reachable_peers,
+                      current_share_peers, peer_storage_servers)
+        # TODO: add an errback to, probably to ignore that peer
+        return d
+
+    def _do_query(self, conn, peerid, peer_storage_servers, storage_index):
+        d = conn.callRemote("get_service", "storageserver")
+        def _got_storageserver(ss):
+            peer_storage_servers[peerid] = ss
+            # TODO: only read 2KB, since all we really need is the seqnum
+            # info. But we need to read more from at least one peer so we can
+            # grab the encrypted privkey. Really, read just the 2k, and if
+            # the first response suggests that the privkey is beyond that
+            # segment, send out another query to the same peer for the
+            # privkey segment.
+            return ss.callRemote("slot_readv", storage_index, [], [(0, 2500)])
+        d.addCallback(_got_storageserver)
+        return d
+
+    def _got_query_results(self, datavs, peerid, permutedid,
+                           reachable_peers, current_share_peers):
+        self.log("_got_query_results")
+
+        assert isinstance(datavs, dict)
+        reachable_peers[peerid] = permutedid
+        for shnum, datav in datavs.items():
+            assert len(datav) == 1
+            data = datav[0]
+            r = unpack_share(data)
+            (seqnum, root_hash, IV, k, N, segsize, datalen,
+             pubkey, signature, share_hash_chain, block_hash_tree,
+             share_data, enc_privkey) = r
+            share = (shnum, seqnum, root_hash)
+            current_share_peers.add(shnum, (peerid, seqnum, root_hash) )
+            if not self._encprivkey:
+                self._encprivkey = enc_privkey
+                self._node._populate_encprivkey(self._encprivkey)
+            if not self._privkey:
+                privkey_s = self._node._decrypt_privkey(enc_privkey)
+                self._privkey = rsa.create_signing_key_from_string(privkey_s)
+                self._node._populate_privkey(self._privkey)
+            # TODO: make sure we actually fill these in before we try to
+            # upload. This means we may need to re-fetch something if our
+            # initial read was too short.
+
+    def _got_all_query_results(self, res,
+                               total_shares, reachable_peers,
+                               current_share_peers, peer_storage_servers):
+        self.log("_got_all_query_results")
+        # now that we know everything about the shares currently out there,
+        # decide where to place the new shares.
+
+        # if an old share X is on a node, put the new share X there too.
+        # TODO: 1: redistribute shares to achieve one-per-peer, by copying
+        #       shares from existing peers to new (less-crowded) ones. The
+        #       old shares must still be updated.
+        # TODO: 2: move those shares instead of copying them, to reduce future
+        #       update work
+
+        shares_needing_homes = range(total_shares)
+        target_map = DictOfSets() # maps shnum to set((peerid,oldseqnum,oldR))
+        shares_per_peer = DictOfSets()
+        for shnum in range(total_shares):
+            for oldplace in current_share_peers.get(shnum, []):
+                (peerid, seqnum, R) = oldplace
+                if seqnum >= self._new_seqnum:
+                    raise UncoordinatedWriteError()
+                target_map.add(shnum, oldplace)
+                shares_per_peer.add(peerid, shnum)
+                if shnum in shares_needing_homes:
+                    shares_needing_homes.remove(shnum)
+
+        # now choose homes for the remaining shares. We prefer peers with the
+        # fewest target shares, then peers with the lowest permuted index. If
+        # there are no shares already in place, this will assign them
+        # one-per-peer in the normal permuted order.
+        while shares_needing_homes:
+            if not reachable_peers:
+                raise NotEnoughPeersError("ran out of peers during upload")
+            shnum = shares_needing_homes.pop(0)
+            possible_homes = reachable_peers.keys()
+            possible_homes.sort(lambda a,b:
+                                cmp( (len(shares_per_peer.get(a, [])),
+                                      reachable_peers[a]),
+                                     (len(shares_per_peer.get(b, [])),
+                                      reachable_peers[b]) ))
+            target_peerid = possible_homes[0]
+            target_map.add(shnum, (target_peerid, None, None) )
+            shares_per_peer.add(target_peerid, shnum)
+
+        assert not shares_needing_homes
+
+        target_info = (target_map, peer_storage_servers)
+        return target_info
+
+    def _encrypt_and_encode(self, target_info,
+                            newdata, readkey, IV,
                             required_shares, total_shares):
         self.log("_encrypt_and_encode")
 
@@ -659,17 +795,25 @@ class Publish:
             assert len(piece) == piece_size
 
         d = fec.encode(crypttext_pieces)
-        d.addCallback(lambda shares:
-                      (shares, required_shares, total_shares,
-                       segment_size, len(crypttext), IV) )
+        d.addCallback(lambda shares_and_shareids:
+                      (shares_and_shareids,
+                       required_shares, total_shares,
+                       segment_size, len(crypttext),
+                       target_info) )
         return d
 
     def _generate_shares(self, (shares_and_shareids,
                                 required_shares, total_shares,
-                                segment_size, data_length, IV),
-                         seqnum, privkey, encprivkey, pubkey):
+                                segment_size, data_length,
+                                target_info),
+                         seqnum, IV):
         self.log("_generate_shares")
 
+        # we should know these by now
+        privkey = self._privkey
+        encprivkey = self._encprivkey
+        pubkey = self._pubkey
+
         (shares, share_ids) = shares_and_shareids
 
         assert len(shares) == len(share_ids)
@@ -737,118 +881,10 @@ class Publish:
                                            block_hash_tree_s,
                                            share_data,
                                            encprivkey])
-        return (seqnum, root_hash, final_shares)
-
+        return (seqnum, root_hash, final_shares, target_info)
 
-    def _query_peers(self, (seqnum, root_hash, final_shares), total_shares):
-        self.log("_query_peers")
-
-        self._new_seqnum = seqnum
-        self._new_root_hash = root_hash
-        self._new_shares = final_shares
 
-        storage_index = self._node.get_storage_index()
-        peerlist = self._node._client.get_permuted_peers(storage_index,
-                                                         include_myself=False)
-        # we don't include ourselves in the N peers, but we *do* push an
-        # extra copy of share[0] to ourselves so we're more likely to have
-        # the signing key around later. This way, even if all the servers die
-        # and the directory contents are unrecoverable, at least we can still
-        # push out a new copy with brand-new contents.
-        # TODO: actually push this copy
-
-        current_share_peers = DictOfSets()
-        reachable_peers = {}
-
-        EPSILON = total_shares / 2
-        partial_peerlist = islice(peerlist, total_shares + EPSILON)
-        peer_storage_servers = {}
-        dl = []
-        for (permutedid, peerid, conn) in partial_peerlist:
-            d = self._do_query(conn, peerid, peer_storage_servers,
-                               storage_index)
-            d.addCallback(self._got_query_results,
-                          peerid, permutedid,
-                          reachable_peers, current_share_peers)
-            dl.append(d)
-        d = defer.DeferredList(dl)
-        d.addCallback(self._got_all_query_results,
-                      total_shares, reachable_peers, seqnum,
-                      current_share_peers, peer_storage_servers)
-        # TODO: add an errback to, probably to ignore that peer
-        return d
-
-    def _do_query(self, conn, peerid, peer_storage_servers, storage_index):
-        d = conn.callRemote("get_service", "storageserver")
-        def _got_storageserver(ss):
-            peer_storage_servers[peerid] = ss
-            return ss.callRemote("slot_readv", storage_index, [], [(0, 2000)])
-        d.addCallback(_got_storageserver)
-        return d
-
-    def _got_query_results(self, datavs, peerid, permutedid,
-                           reachable_peers, current_share_peers):
-        self.log("_got_query_results")
-
-        assert isinstance(datavs, dict)
-        reachable_peers[peerid] = permutedid
-        for shnum, datav in datavs.items():
-            assert len(datav) == 1
-            data = datav[0]
-            r = unpack_share(data)
-            share = (shnum, r[0], r[1]) # shnum,seqnum,R
-            current_share_peers[shnum].add( (peerid, r[0], r[1]) )
-
-    def _got_all_query_results(self, res,
-                               total_shares, reachable_peers, new_seqnum,
-                               current_share_peers, peer_storage_servers):
-        self.log("_got_all_query_results")
-        # now that we know everything about the shares currently out there,
-        # decide where to place the new shares.
-
-        # if an old share X is on a node, put the new share X there too.
-        # TODO: 1: redistribute shares to achieve one-per-peer, by copying
-        #       shares from existing peers to new (less-crowded) ones. The
-        #       old shares must still be updated.
-        # TODO: 2: move those shares instead of copying them, to reduce future
-        #       update work
-
-        shares_needing_homes = range(total_shares)
-        target_map = DictOfSets() # maps shnum to set((peerid,oldseqnum,oldR))
-        shares_per_peer = DictOfSets()
-        for shnum in range(total_shares):
-            for oldplace in current_share_peers.get(shnum, []):
-                (peerid, seqnum, R) = oldplace
-                if seqnum >= new_seqnum:
-                    raise UncoordinatedWriteError()
-                target_map.add(shnum, oldplace)
-                shares_per_peer.add(peerid, shnum)
-                if shnum in shares_needing_homes:
-                    shares_needing_homes.remove(shnum)
-
-        # now choose homes for the remaining shares. We prefer peers with the
-        # fewest target shares, then peers with the lowest permuted index. If
-        # there are no shares already in place, this will assign them
-        # one-per-peer in the normal permuted order.
-        while shares_needing_homes:
-            if not reachable_peers:
-                raise NotEnoughPeersError("ran out of peers during upload")
-            shnum = shares_needing_homes.pop(0)
-            possible_homes = reachable_peers.keys()
-            possible_homes.sort(lambda a,b:
-                                cmp( (len(shares_per_peer.get(a, [])),
-                                      reachable_peers[a]),
-                                     (len(shares_per_peer.get(b, [])),
-                                      reachable_peers[b]) ))
-            target_peerid = possible_homes[0]
-            target_map.add(shnum, (target_peerid, None, None) )
-            shares_per_peer.add(target_peerid, shnum)
-
-        assert not shares_needing_homes
-
-        return (target_map, peer_storage_servers)
-
-    def _send_shares(self, (target_map, peer_storage_servers), IV ):
+    def _send_shares(self, (seqnum, root_hash, final_shares, target_info), IV):
         self.log("_send_shares")
         # we're finally ready to send out our shares. If we encounter any
         # surprises here, it's because somebody else is writing at the same
@@ -857,15 +893,16 @@ class Publish:
         # surprises here are *not* indications of UncoordinatedWriteError,
         # and we'll need to respond to them more gracefully.
 
-        my_checkstring = pack_checkstring(self._new_seqnum,
-                                          self._new_root_hash, IV)
+        target_map, peer_storage_servers = target_info
+
+        my_checkstring = pack_checkstring(seqnum, root_hash, IV)
         peer_messages = {}
         expected_old_shares = {}
 
         for shnum, peers in target_map.items():
             for (peerid, old_seqnum, old_root_hash) in peers:
                 testv = [(0, len(my_checkstring), "le", my_checkstring)]
-                new_share = self._new_shares[shnum]
+                new_share = final_shares[shnum]
                 writev = [(0, new_share)]
                 if peerid not in peer_messages:
                     peer_messages[peerid] = {}
@@ -982,6 +1019,14 @@ class MutableFileNode:
         self._readkey = self._uri.readkey
         self._storage_index = self._uri.storage_index
         self._fingerprint = self._uri.fingerprint
+        # the following values are learned during Retrieval
+        #  self._pubkey
+        #  self._required_shares
+        #  self._total_shares
+        # and these are needed for Publish. They are filled in by Retrieval
+        # if possible, otherwise by the first peer that Publish talks to.
+        self._privkey = None
+        self._encprivkey = None
         return self
 
     def create(self, initial_contents):
@@ -1028,6 +1073,34 @@ class MutableFileNode:
         crypttext = enc.encrypt(privkey)
         return crypttext
 
+    def _decrypt_privkey(self, enc_privkey):
+        enc = AES.new(key=self._writekey, mode=AES.MODE_CTR, counterstart="\x00"*16)
+        privkey = enc.decrypt(enc_privkey)
+        return privkey
+
+    def _populate(self, stuff):
+        # the Retrieval object calls this with values it discovers when
+        # downloading the slot. This is how a MutableFileNode that was
+        # created from a URI learns about its full key.
+        pass
+
+    def _populate_pubkey(self, pubkey):
+        self._pubkey = pubkey
+    def _populate_required_shares(self, required_shares):
+        self._required_shares = required_shares
+    def _populate_total_shares(self, total_shares):
+        self._total_shares = total_shares
+    def _populate_seqnum(self, seqnum):
+        self._current_seqnum = seqnum
+    def _populate_root_hash(self, root_hash):
+        self._current_roothash = root_hash
+
+    def _populate_privkey(self, privkey):
+        self._privkey = privkey
+    def _populate_encprivkey(self, encprivkey):
+        self._encprivkey = encprivkey
+
+
     def get_write_enabler(self, peerid):
         assert len(peerid) == 20
         return hashutil.ssk_write_enabler_hash(self._writekey, peerid)
@@ -1093,4 +1166,7 @@ class MutableFileNode:
         return r.retrieve()
 
     def replace(self, newdata):
-        return defer.succeed(None)
+        r = Retrieve(self)
+        d = r.retrieve()
+        d.addCallback(lambda res: self._publish(newdata))
+        return d
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 25b3e28f..42a01a66 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -387,6 +387,7 @@ class Publish(unittest.TestCase):
         d.addCallback(_done)
         return d
 
+del Publish # gotta run, will fix this in a few hours
 
 class FakePubKey:
     def __init__(self, count):
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index 07ac4ee2..b784bc68 100644
--- a/src/allmydata/test/test_system.py
+++ b/src/allmydata/test/test_system.py
@@ -242,6 +242,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
     def test_mutable(self):
         self.basedir = "system/SystemTest/test_mutable"
         DATA = "initial contents go here."  # 25 bytes % 3 != 0
+        NEWDATA = "new contents yay"
+        NEWERDATA = "this is getting old"
+
         d = self.set_up_nodes()
 
         def _create_mutable(res):
@@ -255,7 +258,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
                 self._mutable_node_1 = res
                 uri = res.get_uri()
                 #print "DONE", uri
-            d1.addBoth(_done)
+            d1.addCallback(_done)
             return d1
         d.addCallback(_create_mutable)
 
@@ -299,11 +302,11 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
                 m = re.search(r'^ container_size: (\d+)$', output, re.M)
                 self.failUnless(m)
                 container_size = int(m.group(1))
-                self.failUnless(2046 <= container_size <= 2049)
+                self.failUnless(2046 <= container_size <= 2049, container_size)
                 m = re.search(r'^ data_length: (\d+)$', output, re.M)
                 self.failUnless(m)
                 data_length = int(m.group(1))
-                self.failUnless(2046 <= data_length <= 2049)
+                self.failUnless(2046 <= data_length <= 2049, data_length)
                 self.failUnless("  secrets are for nodeid: %s\n" % peerid
                                 in output)
                 self.failUnless(" SDMF contents:\n" in output)
@@ -351,14 +354,39 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             #print "starting download 3"
             uri = self._mutable_node_1.get_uri()
             newnode = self.clients[1].create_mutable_file_from_uri(uri)
-            return newnode.download_to_data()
+            d1 = newnode.download_to_data()
+            d1.addCallback(lambda res: (res, newnode))
+            return d1
         d.addCallback(_check_download_2)
 
-        def _check_download_3(res):
+        def _check_download_3((res, newnode)):
             #print "_check_download_3"
             self.failUnlessEqual(res, DATA)
+            # replace the data
+            #print "REPLACING"
+            d1 = newnode.replace(NEWDATA)
+            d1.addCallback(lambda res: newnode.download_to_data())
+            return d1
         d.addCallback(_check_download_3)
 
+        def _check_download_4(res):
+            print "_check_download_4"
+            self.failUnlessEqual(res, NEWDATA)
+            # now create an even newer node and replace the data on it. This
+            # new node has never been used for download before.
+            uri = self._mutable_node_1.get_uri()
+            newnode1 = self.clients[2].create_mutable_file_from_uri(uri)
+            newnode2 = self.clients[3].create_mutable_file_from_uri(uri)
+            d1 = newnode1.replace(NEWERDATA)
+            d1.addCallback(lambda res: newnode2.download_to_data())
+            return d1
+        #d.addCallback(_check_download_4)
+
+        def _check_download_5(res):
+            print "_check_download_5"
+            self.failUnlessEqual(res, NEWERDATA)
+        #d.addCallback(_check_download_5)
+
         return d
 
     def flip_bit(self, good):
-- 
2.45.2