From 2ed394e471165eb688457b15342e430a8dad0e20 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Tue, 6 Nov 2007 15:04:46 -0700
Subject: [PATCH] mutable: move IV into signed prefix, add more retrieval code

---
 docs/mutable.txt                   |  35 +++--
 src/allmydata/mutable.py           | 198 ++++++++++++++++++++---------
 src/allmydata/test/test_mutable.py |  14 +-
 3 files changed, 163 insertions(+), 84 deletions(-)

diff --git a/docs/mutable.txt b/docs/mutable.txt
index f1cc6c28..2d64d60d 100644
--- a/docs/mutable.txt
+++ b/docs/mutable.txt
@@ -378,27 +378,26 @@ offset is used both to terminate the share data and to begin the encprivkey).
  1    0        1       version byte, \x00 for this format
  2    1        8       sequence number. 2^64-1 must be handled specially, TBD
  3    9        32      "R" (root of share hash Merkle tree)
- 4    41       18      encoding parameters:
-       41       1        k
-       42       1        N
-       43       8        segment size
-       51       8        data length (of original plaintext)
- 5    59       36      offset table:
-       59       4        (7) signature
-       63       4        (8) share hash chain
-       67       4        (9) block hash tree
-       71       4        (10) IV
-       75       4        (11) share data
-       79       8        (12) encrypted private key
-       87       8        (13) EOF
- 6    95       256     verification key (2048 RSA key 'n' value, e=3)
- 7    361      256     signature= RSAenc(sig-key, H(version+seqnum+r+encparm))
- 8    607      (a)     share hash chain, encoded as:
+ 4    41       16      IV (share data is AES(H(readkey+IV)) )
+ 5    57       18      encoding parameters:
+       57       1        k
+       58       1        N
+       59       8        segment size
+       67       8        data length (of original plaintext)
+ 6    75       36      offset table:
+       75       4        (8) signature
+       79       4        (9) share hash chain
+       83       4        (10) block hash tree
+       91       4        (11) share data
+       95       8        (12) encrypted private key
+       103      8        (13) EOF
+ 7    111      256     verification key (2048 RSA key 'n' value, e=3)
+ 8    367      256     signature=RSAenc(sigkey, H(version+seqnum+r+IV+encparm))
+ 9    623      (a)     share hash chain, encoded as:
                         "".join([pack(">H32s", shnum, hash)
                                  for (shnum,hash) in needed_hashes])
- 9    ??       (b)     block hash tree, encoded as:
+10    ??       (b)     block hash tree, encoded as:
                         "".join([pack(">32s",hash) for hash in block_hash_tree])
-10    ??       16      IV (share data is AES(H(readkey+IV)) )
 11    ??       LEN     share data (no gap between this and encprivkey)
 12    ??       256     encrypted private key= AESenc(write-key, RSA 'd' value)
 13    ??       --      EOF
diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py
index a1f48a3d..e2bd556a 100644
--- a/src/allmydata/mutable.py
+++ b/src/allmydata/mutable.py
@@ -34,25 +34,27 @@ class CorruptShareError(Exception):
                                                                self.shnum,
                                                                self.reason)
 
-HEADER_LENGTH = struct.calcsize(">BQ32s BBQQ LLLLLQQ")
+PREFIX = ">BQ32s16s" # each version has a different prefix
+SIGNED_PREFIX = ">BQ32s16s BBQQ" # this is covered by the signature
+HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets
+HEADER_LENGTH = struct.calcsize(HEADER)
 
 def unpack_prefix_and_signature(data):
     assert len(data) >= HEADER_LENGTH
     o = {}
-    prefix = data[:struct.calcsize(">BQ32s BBQQ")]
+    prefix = data[:struct.calcsize(SIGNED_PREFIX)]
 
     (version,
      seqnum,
      root_hash,
+     IV,
      k, N, segsize, datalen,
      o['signature'],
      o['share_hash_chain'],
      o['block_hash_tree'],
-     o['IV'],
      o['share_data'],
      o['enc_privkey'],
-     o['EOF']) = struct.unpack(">BQ32s BBQQ LLLLLQQ",
-                               data[:HEADER_LENGTH])
+     o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH])
 
     assert version == 0
     if len(data) < o['share_hash_chain']:
@@ -61,7 +63,7 @@ def unpack_prefix_and_signature(data):
     pubkey_s = data[HEADER_LENGTH:o['signature']]
     signature = data[o['signature']:o['share_hash_chain']]
 
-    return (seqnum, root_hash, k, N, segsize, datalen,
+    return (seqnum, root_hash, IV, k, N, segsize, datalen,
             pubkey_s, signature, prefix)
 
 def unpack_share(data):
@@ -70,15 +72,14 @@ def unpack_share(data):
     (version,
      seqnum,
      root_hash,
+     IV,
      k, N, segsize, datalen,
      o['signature'],
      o['share_hash_chain'],
      o['block_hash_tree'],
-     o['IV'],
      o['share_data'],
      o['enc_privkey'],
-     o['EOF']) = struct.unpack(">BQ32s" + "BBQQ" + "LLLLLQQ",
-                                     data[:HEADER_LENGTH])
+     o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH])
 
     assert version == 0
     if len(data) < o['EOF']:
@@ -95,41 +96,41 @@ def unpack_share(data):
         chunk = share_hash_chain_s[i:i+hsize]
         (hid, h) = struct.unpack(share_hash_format, chunk)
         share_hash_chain.append( (hid, h) )
-    block_hash_tree_s = data[o['block_hash_tree']:o['IV']]
+    block_hash_tree_s = data[o['block_hash_tree']:o['share_data']]
     assert len(block_hash_tree_s) % 32 == 0, len(block_hash_tree_s)
     block_hash_tree = []
     for i in range(0, len(block_hash_tree_s), 32):
         block_hash_tree.append(block_hash_tree_s[i:i+32])
 
-    IV = data[o['IV']:o['share_data']]
     share_data = data[o['share_data']:o['enc_privkey']]
     enc_privkey = data[o['enc_privkey']:o['EOF']]
 
-    return (seqnum, root_hash, k, N, segsize, datalen,
+    return (seqnum, root_hash, IV, k, N, segsize, datalen,
             pubkey, signature, share_hash_chain, block_hash_tree,
-            IV, share_data, enc_privkey)
+            share_data, enc_privkey)
 
 
-def pack_checkstring(seqnum, root_hash):
-    return struct.pack(">BQ32s",
+def pack_checkstring(seqnum, root_hash, IV):
+    return struct.pack(PREFIX,
                        0, # version,
                        seqnum,
-                       root_hash)
+                       root_hash,
+                       IV)
 
 def unpack_checkstring(checkstring):
-    cs_len = struct.calcsize(">BQ32s")
-    version, seqnum, root_hash = struct.unpack(">BQ32s",
-                                               checkstring[:cs_len])
+    cs_len = struct.calcsize(PREFIX)
+    version, seqnum, root_hash, IV = struct.unpack(PREFIX, checkstring[:cs_len])
     assert version == 0 # TODO: just ignore the share
-    return (seqnum, root_hash)
+    return (seqnum, root_hash, IV)
 
-def pack_prefix(seqnum, root_hash,
+def pack_prefix(seqnum, root_hash, IV,
                 required_shares, total_shares,
                 segment_size, data_length):
-    prefix = struct.pack(">BQ32s" + "BBQQ",
+    prefix = struct.pack(SIGNED_PREFIX,
                          0, # version,
                          seqnum,
                          root_hash,
+                         IV,
 
                          required_shares,
                          total_shares,
@@ -140,23 +141,20 @@ def pack_prefix(seqnum, root_hash,
 
 def pack_offsets(verification_key_length, signature_length,
                  share_hash_chain_length, block_hash_tree_length,
-                 IV_length, share_data_length, encprivkey_length):
+                 share_data_length, encprivkey_length):
     post_offset = HEADER_LENGTH
     offsets = {}
     o1 = offsets['signature'] = post_offset + verification_key_length
     o2 = offsets['share_hash_chain'] = o1 + signature_length
     o3 = offsets['block_hash_tree'] = o2 + share_hash_chain_length
-    assert IV_length == 16
-    o4 = offsets['IV'] = o3 + block_hash_tree_length
-    o5 = offsets['share_data'] = o4 + IV_length
-    o6 = offsets['enc_privkey'] = o5 + share_data_length
-    o7 = offsets['EOF'] = o6 + encprivkey_length
+    o4 = offsets['share_data'] = o3 + block_hash_tree_length
+    o5 = offsets['enc_privkey'] = o4 + share_data_length
+    o6 = offsets['EOF'] = o5 + encprivkey_length
 
-    return struct.pack(">LLLLLQQ",
+    return struct.pack(">LLLLQQ",
                        offsets['signature'],
                        offsets['share_hash_chain'],
                        offsets['block_hash_tree'],
-                       offsets['IV'],
                        offsets['share_data'],
                        offsets['enc_privkey'],
                        offsets['EOF'])
@@ -302,6 +300,10 @@ class Retrieve:
         # we'll grab a copy from the first peer we talk to.
         self._pubkey = filenode.get_pubkey()
         self._storage_index = filenode.get_storage_index()
+        self._readkey = filenode.get_readkey()
+
+    def log(self, msg):
+        self._node._client.log(msg)
 
     def retrieve(self):
         """Retrieve the filenode's current contents. Returns a Deferred that
@@ -415,6 +417,11 @@ class Retrieve:
         # TODO
         return None
 
+    def _validate_share(self, root_hash, shnum, data):
+        if False:
+            raise CorruptShareError("explanation")
+        pass
+
     def _got_results(self, datavs, peerid, readsize):
         self._queries_outstanding.discard(peerid)
         self._used_peers.add(peerid)
@@ -423,7 +430,7 @@ class Retrieve:
 
         for shnum,datav in datavs.items():
             data = datav[0]
-            (seqnum, root_hash, k, N, segsize, datalength,
+            (seqnum, root_hash, IV, k, N, segsize, datalength,
              pubkey_s, signature, prefix) = unpack_prefix_and_signature(data)
 
             if not self._pubkey:
@@ -434,7 +441,7 @@ class Retrieve:
                                             "pubkey doesn't match fingerprint")
                 self._pubkey = self._deserialize_pubkey(pubkey_s)
 
-            verinfo = (seqnum, root_hash)
+            verinfo = (seqnum, root_hash, IV)
             if verinfo not in self._valid_versions:
                 # it's a new pair. Verify the signature.
                 valid = self._pubkey.verify(prefix, signature)
@@ -480,28 +487,29 @@ class Retrieve:
         self._bad_peerids.add(peerid)
         short_sid = idlib.a2b(self.storage_index)[:6]
         if f.check(CorruptShareError):
-            self._node._client.log("WEIRD: bad share for %s: %s" %
-                                   (short_sid, f))
+            self.log("WEIRD: bad share for %s: %s" % (short_sid, f))
         else:
-            self._node._client.log("WEIRD: other error for %s: %s" %
-                                   (short_sid, f))
+            self.log("WEIRD: other error for %s: %s" % (short_sid, f))
         self._check_for_done()
 
     def _check_for_done(self):
         share_prefixes = {}
         versionmap = DictOfSets()
-        for prefix, sharemap in self._valid_versions.values():
+        for verinfo, (prefix, sharemap) in self._valid_versions.items():
             if len(sharemap) >= self._required_shares:
                 # this one looks retrievable
-                try:
-                    contents = self._extract_data(sharemap)
-                except CorruptShareError:
-                    # log(WEIRD)
-                    # _extract_data is responsible for removing the bad
-                    # share, so we can just try again
-                    return self._check_for_done()
-                # success!
-                return self._done(contents)
+                d = defer.maybeDeferred(self._extract_data, verinfo, sharemap)
+                def _problem(f):
+                    if f.check(CorruptShareError):
+                        # log(WEIRD)
+                        # _extract_data is responsible for removing the bad
+                        # share, so we can just try again
+                        eventually(self._check_for_done)
+                        return
+                    return f
+                d.addCallbacks(self._done, _problem)
+                return
+
         # we don't have enough shares yet. Should we send out more queries?
         if self._queries_outstanding:
             # there are some running, so just wait for them to come back.
@@ -544,6 +552,77 @@ class Retrieve:
         # we've used up all the peers we're allowed to search. Failure.
         return self._done(failure.Failure(NotEnoughPeersError()))
 
+    def _extract_data(self, verinfo, sharemap):
+        # sharemap is a dict which maps shnum to [(peerid,data)..] sets.
+        (seqnum, root_hash, IV) = verinfo
+
+        # first, validate each share that we haven't validated yet. We use
+        # self._valid_shares to remember which ones we've already checked.
+
+        self._valid_shares = set()  # set of (peerid,data) sets
+        shares = {}
+        for shnum, shareinfo in sharemap.items():
+            if shareinfo not in self._valid_shares:
+                (peerid,data) = shareinfo
+                try:
+                    # The (seqnum+root_hash+IV) tuple for this share was
+                    # already verified: specifically, all shares in the
+                    # sharemap have a (seqnum+root_hash+IV) pair that was
+                    # present in a validly signed prefix. The remainder of
+                    # the prefix for this particular share has *not* been
+                    # validated, but we don't care since we don't use it.
+                    # self._validate_share() is required to check the hashes
+                    # on the share data (and hash chains) to make sure they
+                    # match root_hash, but is not required (and is in fact
+                    # prohibited, because we don't validate the prefix on all
+                    # shares) from using anything else in the share.
+                    sharedata = self._validate_share(root_hash, shnum, data)
+                except CorruptShareError, e:
+                    self.log("WEIRD: share was corrupt: %s" % e)
+                    sharemap[shnum].discard(shareinfo)
+                    # If there are enough remaining shares, _check_for_done()
+                    # will try again
+                    raise
+                self._valid_shares.add(shareinfo)
+                shares[shnum] = sharedata
+        # at this point, all shares in the sharemap are valid, and they're
+        # all for the same seqnum+root_hash version, so it's now down to
+        # doing FEC and decrypt.
+        d = defer.maybeDeferred(self._decode, shares)
+        d.addCallback(self._decrypt, IV)
+        return d
+
+    def _decode(self, shares_dict):
+        # shares_dict is a dict mapping shnum to share data, but the codec
+        # wants two lists.
+        shareids = []; shares = []
+        for shareid, share in shares_dict.items():
+            shareids.append(shareid)
+            shares.append(share)
+
+        fec = codec.CRSDecoder()
+        # we ought to know these values by now
+        assert self._segsize is not None
+        assert self._required_shares is not None
+        assert self._total_shares is not None
+        params = "%d-%d-%d" % (self._segsize,
+                               self._required_shares, self._total_shares)
+        fec.set_serialized_params(params)
+
+        d = fec.decode(shares, shareids)
+        def _done(buffers):
+            segment = "".join(buffers)
+            segment = segment[:self._datalength]
+            return segment
+        d.addCallback(_done)
+        return d
+
+    def _decrypt(self, crypttext, IV):
+        key = hashutil.ssk_readkey_data_hash(IV, self._readkey)
+        decryptor = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16)
+        plaintext = decryptor.decrypt(crypttext)
+        return plaintext
+
     def _done(self, contents):
         self._running = False
         eventually(self._done_deferred.callback, contents)
@@ -588,21 +667,22 @@ class Publish:
         encprivkey = self._node.get_encprivkey()
         pubkey = self._node.get_pubkey()
 
+        IV = os.urandom(16)
+
         d = defer.succeed(newdata)
-        d.addCallback(self._encrypt_and_encode, readkey,
+        d.addCallback(self._encrypt_and_encode, readkey, IV,
                       required_shares, total_shares)
         d.addCallback(self._generate_shares, old_seqnum+1,
                       privkey, encprivkey, pubkey)
 
         d.addCallback(self._query_peers, total_shares)
-        d.addCallback(self._send_shares)
+        d.addCallback(self._send_shares, IV)
         d.addCallback(self._maybe_recover)
         d.addCallback(lambda res: None)
         return d
 
-    def _encrypt_and_encode(self, newdata, readkey,
+    def _encrypt_and_encode(self, newdata, readkey, IV,
                             required_shares, total_shares):
-        IV = os.urandom(16)
         key = hashutil.ssk_readkey_data_hash(IV, readkey)
         enc = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16)
         crypttext = enc.encrypt(newdata)
@@ -666,7 +746,7 @@ class Publish:
         root_hash = share_hash_tree[0]
         assert len(root_hash) == 32
 
-        prefix = pack_prefix(seqnum, root_hash,
+        prefix = pack_prefix(seqnum, root_hash, IV,
                              required_shares, total_shares,
                              segment_size, data_length)
 
@@ -694,7 +774,6 @@ class Publish:
                                    len(signature),
                                    len(share_hash_chain_s),
                                    len(block_hash_tree_s),
-                                   len(IV),
                                    len(share_data),
                                    len(encprivkey))
 
@@ -704,7 +783,6 @@ class Publish:
                                            signature,
                                            share_hash_chain_s,
                                            block_hash_tree_s,
-                                           IV,
                                            share_data,
                                            encprivkey])
         return (seqnum, root_hash, final_shares)
@@ -812,7 +890,7 @@ class Publish:
 
         return (target_map, peer_storage_servers)
 
-    def _send_shares(self, (target_map, peer_storage_servers) ):
+    def _send_shares(self, (target_map, peer_storage_servers), IV ):
         # we're finally ready to send out our shares. If we encounter any
         # surprises here, it's because somebody else is writing at the same
         # time. (Note: in the future, when we remove the _query_peers() step
@@ -821,7 +899,7 @@ class Publish:
         # and we'll need to respond to them more gracefully.
 
         my_checkstring = pack_checkstring(self._new_seqnum,
-                                          self._new_root_hash)
+                                          self._new_root_hash, IV)
         peer_messages = {}
         expected_old_shares = {}
 
@@ -884,14 +962,14 @@ class Publish:
             surprised = True
 
         for shnum, (old_cs,) in read_data.items():
-            old_seqnum, old_root_hash = unpack_checkstring(old_cs)
+            (old_seqnum, old_root_hash, IV) = unpack_checkstring(old_cs)
             if wrote and shnum in tw_vectors:
-                current_cs = my_checkstring
+                cur_cs = my_checkstring
             else:
-                current_cs = old_cs
+                cur_cs = old_cs
 
-            current_seqnum, current_root_hash = unpack_checkstring(current_cs)
-            dispatch_map.add(shnum, (peerid, current_seqnum, current_root_hash))
+            (cur_seqnum, cur_root_hash, IV) = unpack_checkstring(cur_cs)
+            dispatch_map.add(shnum, (peerid, cur_seqnum, cur_root_hash))
 
             if shnum not in expected_old_shares:
                 # surprise! there was a share we didn't know about
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index ea5bb411..4176b813 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -167,7 +167,7 @@ class Publish(unittest.TestCase):
         fn.create(CONTENTS)
         p = mutable.Publish(fn)
         d = defer.maybeDeferred(p._encrypt_and_encode,
-                                CONTENTS, "READKEY", 3, 10)
+                                CONTENTS, "READKEY", "IV"*8, 3, 10)
         def _done( ((shares, share_ids),
                     required_shares, total_shares,
                     segsize, data_length, IV) ):
@@ -212,12 +212,12 @@ class Publish(unittest.TestCase):
             self.failUnlessEqual(sorted(final_shares.keys()), range(10))
             for i,sh in final_shares.items():
                 self.failUnless(isinstance(sh, str))
-                self.failUnlessEqual(len(sh), 369)
+                self.failUnlessEqual(len(sh), 381)
                 # feed the share through the unpacker as a sanity-check
                 pieces = mutable.unpack_share(sh)
-                (u_seqnum, u_root_hash, k, N, segsize, datalen,
+                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                  pubkey, signature, share_hash_chain, block_hash_tree,
-                 IV, share_data, enc_privkey) = pieces
+                 share_data, enc_privkey) = pieces
                 self.failUnlessEqual(u_seqnum, 3)
                 self.failUnlessEqual(u_root_hash, root_hash)
                 self.failUnlessEqual(k, 3)
@@ -225,7 +225,8 @@ class Publish(unittest.TestCase):
                 self.failUnlessEqual(segsize, 21)
                 self.failUnlessEqual(datalen, len(CONTENTS))
                 self.failUnlessEqual(pubkey, FakePubKey(0).serialize())
-                sig_material = struct.pack(">BQ32s BBQQ", 0, seqnum, root_hash,
+                sig_material = struct.pack(">BQ32s16s BBQQ",
+                                           0, seqnum, root_hash, IV,
                                            k, N, segsize, datalen)
                 self.failUnlessEqual(signature,
                                      FakePrivKey(0).sign(sig_material))
@@ -355,7 +356,8 @@ class Publish(unittest.TestCase):
         total_shares = 10
         d, p = self.setup_for_write(20, total_shares)
         d.addCallback(p._query_peers, total_shares)
-        d.addCallback(p._send_shares)
+        IV = "IV"*8
+        d.addCallback(p._send_shares, IV)
         def _done((surprised, dispatch_map)):
             self.failIf(surprised, "surprised!")
         d.addCallback(_done)
-- 
2.45.2