From be94960680dcd260ac6de8f29b452d1fd44b5198 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Wed, 7 Nov 2007 14:19:01 -0700
Subject: [PATCH] mutable: test roundtrip, make it work

---
 src/allmydata/interfaces.py       |   2 +-
 src/allmydata/mutable.py          | 185 ++++++++++++++++++++++--------
 src/allmydata/test/test_system.py |  43 ++++++-
 3 files changed, 177 insertions(+), 53 deletions(-)

diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index 171d2da4..37aad178 100644
--- a/src/allmydata/interfaces.py
+++ b/src/allmydata/interfaces.py
@@ -146,7 +146,7 @@ class RIStorageServer(RemoteInterface):
         """Read a vector from the numbered shares associated with the given
         storage index. An empty shares list means to return data from all
         known shares. Returns a dictionary with one key per share."""
-        return DictOf(int, DataVector) # shnum -> results
+        return DictOf(int, ReadData) # shnum -> results
 
     def slot_testv_and_readv_and_writev(storage_index=StorageIndex,
                                         secrets=TupleOf(Hash, Hash, Hash),
diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py
index 7ced3b1e..dd063252 100644
--- a/src/allmydata/mutable.py
+++ b/src/allmydata/mutable.py
@@ -166,9 +166,11 @@ class Retrieve:
         self._pubkey = filenode.get_pubkey()
         self._storage_index = filenode.get_storage_index()
         self._readkey = filenode.get_readkey()
+        self._last_failure = None
 
     def log(self, msg):
-        self._node._client.log(msg)
+        #self._node._client.log(msg)
+        pass
 
     def retrieve(self):
         """Retrieve the filenode's current contents. Returns a Deferred that
@@ -218,12 +220,37 @@ class Retrieve:
         self._segsize = None
         self._datalength = None
 
+        # self._valid_versions is a dictionary in which the keys are
+        # 'verinfo' tuples (seqnum, root_hash, IV). Every time we hear about
+        # a new potential version of the file, we check its signature, and
+        # the valid ones are added to this dictionary. The values of the
+        # dictionary are (prefix, sharemap) tuples, where 'prefix' is just
+        # the first part of the share (containing the serialized verinfo),
+        # for easier comparison. 'sharemap' is a DictOfSets, in which the
+        # keys are sharenumbers, and the values are sets of (peerid, data)
+        # tuples. There is a (peerid, data) tuple for every instance of a
+        # given share that we've seen. The 'data' in this tuple is a full
+        # copy of the SDMF share, starting with the \x00 version byte and
+        # continuing through the last byte of sharedata.
+        self._valid_versions = {}
+
+        # self._valid_shares is a set (peerid,data) tuples. Each time we
+        # examine the hash chains inside a share and validate them against a
+        # signed root_hash, we add the share to self._valid_shares . We use
+        # this to avoid re-checking the hashes over and over again.
+        self._valid_shares = set()
+
+        self._done_deferred = defer.Deferred()
+
         d = defer.succeed(initial_query_count)
         d.addCallback(self._choose_initial_peers)
         d.addCallback(self._send_initial_requests)
-        d.addCallback(lambda res: self._contents)
+        d.addCallback(self._wait_for_finish)
         return d
 
+    def _wait_for_finish(self, res):
+        return self._done_deferred
+
     def _choose_initial_peers(self, numqueries):
         n = self._node
         full_peerlist = n._client.get_permuted_peers(self._storage_index,
@@ -246,10 +273,11 @@ class Retrieve:
         self._bad_peerids = set()
         self._running = True
         self._queries_outstanding = set()
+        self._used_peers = set()
         self._sharemap = DictOfSets() # shnum -> [(peerid, seqnum, R)..]
         self._peer_storage_servers = {}
         dl = []
-        for (permutedid, peerid, conn) in peerlist:
+        for (peerid, conn) in peerlist:
             self._queries_outstanding.add(peerid)
             self._do_query(conn, peerid, self._storage_index, self._read_size,
                            self._peer_storage_servers)
@@ -257,9 +285,7 @@ class Retrieve:
         # control flow beyond this point: state machine. Receiving responses
         # from queries is the input. We might send out more queries, or we
         # might produce a result.
-
-        d = self._done_deferred = defer.Deferred()
-        return d
+        return None
 
     def _do_query(self, conn, peerid, storage_index, readsize,
                   peer_storage_servers):
@@ -281,13 +307,10 @@ class Retrieve:
 
     def _deserialize_pubkey(self, pubkey_s):
         # TODO
+        from allmydata.test.test_mutable import FakePubKey
+        return FakePubKey(0)
         return None
 
-    def _validate_share(self, root_hash, shnum, data):
-        if False:
-            raise CorruptShareError("explanation")
-        pass
-
     def _got_results(self, datavs, peerid, readsize):
         self._queries_outstanding.discard(peerid)
         self._used_peers.add(peerid)
@@ -350,8 +373,9 @@ class Retrieve:
             self._do_query(conn, peerid, storage_index, self._read_size,
                            peer_storage_servers)
             return
+        self._last_failure = f
         self._bad_peerids.add(peerid)
-        short_sid = idlib.a2b(self.storage_index)[:6]
+        short_sid = idlib.b2a(self._storage_index)[:6]
         if f.check(CorruptShareError):
             self.log("WEIRD: bad share for %s: %s" % (short_sid, f))
         else:
@@ -362,13 +386,17 @@ class Retrieve:
         share_prefixes = {}
         versionmap = DictOfSets()
         for verinfo, (prefix, sharemap) in self._valid_versions.items():
+            # sharemap is a dict that maps shnums to sets of (peerid,data).
+            # len(sharemap) is the number of distinct shares that appear to
+            # be available.
             if len(sharemap) >= self._required_shares:
                 # this one looks retrievable
-                d = defer.maybeDeferred(self._extract_data, verinfo, sharemap)
+                d = defer.maybeDeferred(self._attempt_decode, verinfo, sharemap)
                 def _problem(f):
+                    self._last_failure = f
                     if f.check(CorruptShareError):
                         # log(WEIRD)
-                        # _extract_data is responsible for removing the bad
+                        # _attempt_decode is responsible for removing the bad
                         # share, so we can just try again
                         eventually(self._check_for_done)
                         return
@@ -416,41 +444,45 @@ class Retrieve:
             return
 
         # we've used up all the peers we're allowed to search. Failure.
-        return self._done(failure.Failure(NotEnoughPeersError()))
+        e = NotEnoughPeersError("last failure: %s" % self._last_failure)
+        return self._done(failure.Failure(e))
 
-    def _extract_data(self, verinfo, sharemap):
+    def _attempt_decode(self, verinfo, sharemap):
         # sharemap is a dict which maps shnum to [(peerid,data)..] sets.
         (seqnum, root_hash, IV) = verinfo
 
         # first, validate each share that we haven't validated yet. We use
         # self._valid_shares to remember which ones we've already checked.
 
-        self._valid_shares = set()  # set of (peerid,data) sets
         shares = {}
-        for shnum, shareinfo in sharemap.items():
-            if shareinfo not in self._valid_shares:
-                (peerid,data) = shareinfo
-                try:
-                    # The (seqnum+root_hash+IV) tuple for this share was
-                    # already verified: specifically, all shares in the
-                    # sharemap have a (seqnum+root_hash+IV) pair that was
-                    # present in a validly signed prefix. The remainder of
-                    # the prefix for this particular share has *not* been
-                    # validated, but we don't care since we don't use it.
-                    # self._validate_share() is required to check the hashes
-                    # on the share data (and hash chains) to make sure they
-                    # match root_hash, but is not required (and is in fact
-                    # prohibited, because we don't validate the prefix on all
-                    # shares) from using anything else in the share.
-                    sharedata = self._validate_share(root_hash, shnum, data)
-                except CorruptShareError, e:
-                    self.log("WEIRD: share was corrupt: %s" % e)
-                    sharemap[shnum].discard(shareinfo)
-                    # If there are enough remaining shares, _check_for_done()
-                    # will try again
-                    raise
-                self._valid_shares.add(shareinfo)
-                shares[shnum] = sharedata
+        for shnum, shareinfos in sharemap.items():
+            for shareinfo in shareinfos:
+                if shareinfo not in self._valid_shares:
+                    (peerid,data) = shareinfo
+                    try:
+                        # The (seqnum+root_hash+IV) tuple for this share was
+                        # already verified: specifically, all shares in the
+                        # sharemap have a (seqnum+root_hash+IV) pair that was
+                        # present in a validly signed prefix. The remainder
+                        # of the prefix for this particular share has *not*
+                        # been validated, but we don't care since we don't
+                        # use it. self._validate_share() is required to check
+                        # the hashes on the share data (and hash chains) to
+                        # make sure they match root_hash, but is not required
+                        # (and is in fact prohibited, because we don't
+                        # validate the prefix on all shares) from using
+                        # anything else in the share.
+                        validator = self._validate_share_and_extract_data
+                        sharedata = validator(root_hash, shnum, data)
+                        assert isinstance(sharedata, str)
+                    except CorruptShareError, e:
+                        self.log("WEIRD: share was corrupt: %s" % e)
+                        sharemap[shnum].discard(shareinfo)
+                        # If there are enough remaining shares,
+                        # _check_for_done() will try again
+                        raise
+                    self._valid_shares.add(shareinfo)
+                    shares[shnum] = sharedata
         # at this point, all shares in the sharemap are valid, and they're
         # all for the same seqnum+root_hash version, so it's now down to
         # doing FEC and decrypt.
@@ -458,7 +490,36 @@ class Retrieve:
         d.addCallback(self._decrypt, IV)
         return d
 
+    def _validate_share_and_extract_data(self, root_hash, shnum, data):
+        # 'data' is the whole SMDF share
+        self.log("_validate_share_and_extract_data[%d]" % shnum)
+        assert data[0] == "\x00"
+        pieces = unpack_share(data)
+        (seqnum, root_hash, IV, k, N, segsize, datalen,
+         pubkey, signature, share_hash_chain, block_hash_tree,
+         share_data, enc_privkey) = pieces
+
+        assert isinstance(share_data, str)
+        # build the block hash tree. SDMF has only one leaf.
+        leaves = [hashutil.block_hash(share_data)]
+        t = hashtree.HashTree(leaves)
+        if list(t) != block_hash_tree:
+            raise CorruptShareError("block hash tree failure")
+        share_hash_leaf = t[0]
+        # t2 = hashtree.IncompleteHashTree()
+        # TODO: use shnum, share_hash_leaf, share_hash_chain to compare against
+        # root_hash
+        #if False:
+        #    raise CorruptShareError("explanation")
+        self.log(" data valid! len=%d" % len(share_data))
+        return share_data
+
     def _decode(self, shares_dict):
+        # we ought to know these values by now
+        assert self._segsize is not None
+        assert self._required_shares is not None
+        assert self._total_shares is not None
+
         # shares_dict is a dict mapping shnum to share data, but the codec
         # wants two lists.
         shareids = []; shares = []
@@ -466,21 +527,29 @@ class Retrieve:
             shareids.append(shareid)
             shares.append(share)
 
+        # zfec really doesn't want extra shares
+        shareids = shareids[:self._required_shares]
+        shares = shares[:self._required_shares]
+
         fec = codec.CRSDecoder()
-        # we ought to know these values by now
-        assert self._segsize is not None
-        assert self._required_shares is not None
-        assert self._total_shares is not None
         params = "%d-%d-%d" % (self._segsize,
                                self._required_shares, self._total_shares)
         fec.set_serialized_params(params)
 
-        d = fec.decode(shares, shareids)
+        self.log("params %s, we have %d shares" % (params, len(shares)))
+        self.log("about to decode, shareids=%s" % (shareids,))
+        d = defer.maybeDeferred(fec.decode, shares, shareids)
         def _done(buffers):
+            self.log(" decode done, %d buffers" % len(buffers))
             segment = "".join(buffers)
             segment = segment[:self._datalength]
+            self.log(" segment len=%d" % len(segment))
             return segment
+        def _err(f):
+            self.log(" decode failed: %s" % f)
+            return f
         d.addCallback(_done)
+        d.addErrback(_err)
         return d
 
     def _decrypt(self, crypttext, IV):
@@ -490,6 +559,7 @@ class Retrieve:
         return plaintext
 
     def _done(self, contents):
+        self.log("DONE, contents: %r" % contents)
         self._running = False
         eventually(self._done_deferred.callback, contents)
 
@@ -508,6 +578,10 @@ class Publish:
     def __init__(self, filenode):
         self._node = filenode
 
+    def log(self, msg):
+        prefix = idlib.b2a(self._node.get_storage_index())[:6]
+        #self._node._client.log("%s: %s" % (prefix, msg))
+
     def publish(self, newdata):
         """Publish the filenode's current contents. Returns a Deferred that
         fires (with None) when the publish has done as much work as it's ever
@@ -523,6 +597,8 @@ class Publish:
         # 4a: may need to run recovery algorithm
         # 5: when enough responses are back, we're done
 
+        self.log("starting publish")
+
         old_roothash = self._node._current_roothash
         old_seqnum = self._node._current_seqnum
 
@@ -549,6 +625,8 @@ class Publish:
 
     def _encrypt_and_encode(self, newdata, readkey, IV,
                             required_shares, total_shares):
+        self.log("_encrypt_and_encode")
+
         key = hashutil.ssk_readkey_data_hash(IV, readkey)
         enc = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16)
         crypttext = enc.encrypt(newdata)
@@ -583,6 +661,7 @@ class Publish:
                                 required_shares, total_shares,
                                 segment_size, data_length, IV),
                          seqnum, privkey, encprivkey, pubkey):
+        self.log("_generate_shares")
 
         (shares, share_ids) = shares_and_shareids
 
@@ -655,6 +734,8 @@ class Publish:
 
 
     def _query_peers(self, (seqnum, root_hash, final_shares), total_shares):
+        self.log("_query_peers")
+
         self._new_seqnum = seqnum
         self._new_root_hash = root_hash
         self._new_shares = final_shares
@@ -700,6 +781,8 @@ class Publish:
 
     def _got_query_results(self, datavs, peerid, permutedid,
                            reachable_peers, current_share_peers):
+        self.log("_got_query_results")
+
         assert isinstance(datavs, dict)
         reachable_peers[peerid] = permutedid
         for shnum, datav in datavs.items():
@@ -712,6 +795,7 @@ class Publish:
     def _got_all_query_results(self, res,
                                total_shares, reachable_peers, new_seqnum,
                                current_share_peers, peer_storage_servers):
+        self.log("_got_all_query_results")
         # now that we know everything about the shares currently out there,
         # decide where to place the new shares.
 
@@ -758,6 +842,7 @@ class Publish:
         return (target_map, peer_storage_servers)
 
     def _send_shares(self, (target_map, peer_storage_servers), IV ):
+        self.log("_send_shares")
         # we're finally ready to send out our shares. If we encounter any
         # surprises here, it's because somebody else is writing at the same
         # time. (Note: in the future, when we remove the _query_peers() step
@@ -821,6 +906,7 @@ class Publish:
     def _got_write_answer(self, answer, tw_vectors, my_checkstring,
                           peerid, expected_old_shares,
                           dispatch_map):
+        self.log("_got_write_answer: %r" % (answer,))
         wrote, read_data = answer
         surprised = False
 
@@ -851,6 +937,7 @@ class Publish:
             self._surprised = True
 
     def _maybe_recover(self, (surprised, dispatch_map)):
+        self.log("_maybe_recover")
         if not surprised:
             return
         print "RECOVERY NOT YET IMPLEMENTED"
@@ -886,6 +973,7 @@ class MutableFileNode:
         self._writekey = self._uri.writekey
         self._readkey = self._uri.readkey
         self._storage_index = self._uri.storage_index
+        self._fingerprint = self._uri.fingerprint
         return self
 
     def create(self, initial_contents):
@@ -996,9 +1084,8 @@ class MutableFileNode:
         raise NotImplementedError
 
     def download_to_data(self):
-        #downloader = self._client.getServiceNamed("downloader")
-        #return downloader.download_to_data(self.uri)
-        return defer.succeed("this isn't going to fool you, is it")
+        r = Retrieve(self)
+        return r.retrieve()
 
     def replace(self, newdata):
         return defer.succeed(None)
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index a58b1389..6cf99eaf 100644
--- a/src/allmydata/test/test_system.py
+++ b/src/allmydata/test/test_system.py
@@ -241,7 +241,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
 
     def test_mutable(self):
         self.basedir = "system/SystemTest/test_mutable"
-        DATA = "Some data to upload\n" * 200
+        DATA = "initial contents go here."  # 25 bytes % 3 != 0
         d = self.set_up_nodes()
 
         def _create_mutable(res):
@@ -249,10 +249,12 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             #print "CREATING MUTABLE FILENODE"
             c = self.clients[0]
             n = MutableFileNode(c)
-            d1 = n.create("initial contents go here.") # 25 bytes % 3 != 0
+            d1 = n.create(DATA)
             def _done(res):
                 log.msg("DONE: %s" % (res,))
-                #print "DONE", res
+                self._mutable_node_1 = res
+                uri = res.get_uri()
+                #print "DONE", uri
             d1.addBoth(_done)
             return d1
         d.addCallback(_create_mutable)
@@ -314,6 +316,41 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
                 raise
         d.addCallback(_test_debug)
 
+        # test retrieval
+
+        # first, let's see if we can use the existing node to retrieve the
+        # contents. This allows it to use the cached pubkey and maybe the
+        # latest-known sharemap.
+
+        d.addCallback(lambda res: self._mutable_node_1.download_to_data())
+        def _check_download_1(res):
+            #print "_check_download_1"
+            self.failUnlessEqual(res, DATA)
+            # now we see if we can retrieve the data from a new node,
+            # constructed using the URI of the original one. We do this test
+            # on the same client that uploaded the data.
+            #print "download1 good, starting download2"
+            uri = self._mutable_node_1.get_uri()
+            newnode = self.clients[0].create_mutable_file_from_uri(uri)
+            return newnode.download_to_data()
+            return d
+        d.addCallback(_check_download_1)
+
+        def _check_download_2(res):
+            #print "_check_download_2"
+            self.failUnlessEqual(res, DATA)
+            # same thing, but with a different client
+            #print "starting download 3"
+            uri = self._mutable_node_1.get_uri()
+            newnode = self.clients[1].create_mutable_file_from_uri(uri)
+            return newnode.download_to_data()
+        d.addCallback(_check_download_2)
+
+        def _check_download_3(res):
+            #print "_check_download_3"
+            self.failUnlessEqual(res, DATA)
+        d.addCallback(_check_download_3)
+
         return d
 
     def flip_bit(self, good):
-- 
2.45.2