git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
mutable: grab encprivkey when necessary during publish, fix test_mutable
author Brian Warner <warner@lothar.com>
Thu, 8 Nov 2007 09:46:27 +0000 (02:46 -0700)
committer Brian Warner <warner@lothar.com>
Thu, 8 Nov 2007 09:46:27 +0000 (02:46 -0700)
docs/mutable.txt
src/allmydata/mutable.py
src/allmydata/test/test_mutable.py
src/allmydata/util/idlib.py

index a67e8973ec5ee7bd8ca63380915348f02ede4188..ee01d0449ac775ee1e54caea728d10a322f51da9 100644 (file)
@@ -391,15 +391,15 @@ offset is used both to terminate the share data and to begin the encprivkey).
        91       4        (11) share data
        95       8        (12) encrypted private key
        103      8        (13) EOF
- 7    111      256     verification key (2048 RSA key 'n' value, e=3)
- 8    367      256     signature=RSAenc(sigkey, H(version+seqnum+r+IV+encparm))
- 9    623      (a)     share hash chain, encoded as:
+ 7    111      292ish  verification key (2048 RSA key)
+ 8    367ish   256ish  signature=RSAenc(sigkey, H(version+seqnum+r+IV+encparm))
+ 9    623ish   (a)     share hash chain, encoded as:
                         "".join([pack(">H32s", shnum, hash)
                                  for (shnum,hash) in needed_hashes])
 10    ??       (b)     block hash tree, encoded as:
                         "".join([pack(">32s",hash) for hash in block_hash_tree])
 11    ??       LEN     share data (no gap between this and encprivkey)
-12    ??       256     encrypted private key= AESenc(write-key, RSA 'd' value)
+12    ??       1216ish encrypted private key= AESenc(write-key, RSA-key)
 13    ??       --      EOF
 
 (a) The share hash chain contains ceil(log(N)) hashes, each 32 bytes long.
index a3d93ab50174572442ac97a9abff89c6ad559e1f..7aea078fbd5e9212add1307f9d53a90314aff4bd 100644 (file)
@@ -15,9 +15,11 @@ from pycryptopp.publickey import rsa
 
 
 class NeedMoreDataError(Exception):
-    def __init__(self, needed_bytes):
+    def __init__(self, needed_bytes, encprivkey_offset, encprivkey_length):
         Exception.__init__(self)
-        self.needed_bytes = needed_bytes
+        self.needed_bytes = needed_bytes # up through EOF
+        self.encprivkey_offset = encprivkey_offset
+        self.encprivkey_length = encprivkey_length
     def __str__(self):
         return "<NeedMoreDataError (%d bytes)>" % self.needed_bytes
 
@@ -59,7 +61,8 @@ def unpack_prefix_and_signature(data):
 
     assert version == 0
     if len(data) < o['share_hash_chain']:
-        raise NeedMoreDataError(o['share_hash_chain'])
+        raise NeedMoreDataError(o['share_hash_chain'],
+                                o['enc_privkey'], o['EOF']-o['enc_privkey'])
 
     pubkey_s = data[HEADER_LENGTH:o['signature']]
     signature = data[o['signature']:o['share_hash_chain']]
@@ -84,7 +87,8 @@ def unpack_share(data):
 
     assert version == 0
     if len(data) < o['EOF']:
-        raise NeedMoreDataError(o['EOF'])
+        raise NeedMoreDataError(o['EOF'],
+                                o['enc_privkey'], o['EOF']-o['enc_privkey'])
 
     pubkey = data[HEADER_LENGTH:o['signature']]
     signature = data[o['signature']:o['share_hash_chain']]
@@ -593,7 +597,10 @@ class Publish:
 
     def log(self, msg):
         prefix = idlib.b2a(self._node.get_storage_index())[:6]
-        #self._node._client.log("%s: %s" % (prefix, msg))
+        self._node._client.log("%s: %s" % (prefix, msg))
+
+    def log_err(self, f):
+        log.err(f)
 
     def publish(self, newdata):
         """Publish the filenode's current contents. Returns a Deferred that
@@ -612,6 +619,10 @@ class Publish:
 
         self.log("starting publish")
 
+        self._storage_index = self._node.get_storage_index()
+        self._writekey = self._node.get_writekey()
+        assert self._writekey, "need write capability to publish"
+
         old_roothash = self._node._current_roothash
         old_seqnum = self._node._current_seqnum
         assert old_seqnum is not None, "must read before replace"
@@ -629,8 +640,18 @@ class Publish:
 
         IV = os.urandom(16)
 
+        # we read only 1KB because all we generally care about is the seqnum
+        # ("prefix") info, so we know which shares are where. We need to get
+        # the privkey from somebody, which means reading more like 3KB, but
+        # the code in _obtain_privkey will ensure that we manage that even if
+        # we need an extra roundtrip. TODO: arrange to read 3KB from one peer
+        # who is likely to hold a share (like, say, ourselves), so we can
+        # avoid the latency of that extra roundtrip.
+        self._read_size = 1000
+
         d = defer.succeed(total_shares)
         d.addCallback(self._query_peers)
+        d.addCallback(self._obtain_privkey)
 
         d.addCallback(self._encrypt_and_encode, newdata, readkey, IV,
                       required_shares, total_shares)
@@ -644,7 +665,7 @@ class Publish:
     def _query_peers(self, total_shares):
         self.log("_query_peers")
 
-        storage_index = self._node.get_storage_index()
+        storage_index = self._storage_index
         peerlist = self._node._client.get_permuted_peers(storage_index,
                                                          include_myself=False)
         # we don't include ourselves in the N peers, but we *do* push an
@@ -656,6 +677,8 @@ class Publish:
 
         current_share_peers = DictOfSets()
         reachable_peers = {}
+        # list of (peerid, shnum, offset, length) where encprivkey might be found
+        self._encprivkey_shares = []
 
         EPSILON = total_shares / 2
         partial_peerlist = islice(peerlist, total_shares + EPSILON)
@@ -679,13 +702,8 @@ class Publish:
         d = conn.callRemote("get_service", "storageserver")
         def _got_storageserver(ss):
             peer_storage_servers[peerid] = ss
-            # TODO: only read 2KB, since all we really need is the seqnum
-            # info. But we need to read more from at least one peer so we can
-            # grab the encrypted privkey. Really, read just the 2k, and if
-            # the first response suggests that the privkey is beyond that
-            # segment, send out another query to the same peer for the
-            # privkey segment.
-            return ss.callRemote("slot_readv", storage_index, [], [(0, 2500)])
+            return ss.callRemote("slot_readv",
+                                 storage_index, [], [(0, self._read_size)])
         d.addCallback(_got_storageserver)
         return d
 
@@ -698,22 +716,67 @@ class Publish:
         for shnum, datav in datavs.items():
             assert len(datav) == 1
             data = datav[0]
-            r = unpack_share(data)
+            # We want (seqnum, root_hash, IV) from all servers to know what
+            # versions we are replacing. We want the encprivkey from one
+            # server (assuming it's valid) so we know our own private key, so
+            # we can sign our update. SDMF: read the whole share from each
+            # server. TODO: later we can optimize this to transfer less data.
+
+            # we assume that we have enough data to extract the signature.
+            # TODO: if this raises NeedMoreDataError, arrange to do another
+            # read pass.
+            r = unpack_prefix_and_signature(data)
             (seqnum, root_hash, IV, k, N, segsize, datalen,
-             pubkey, signature, share_hash_chain, block_hash_tree,
-             share_data, enc_privkey) = r
+             pubkey_s, signature, prefix) = r
+
+            # TODO: consider verifying the signature here. It's expensive.
+            # What can an attacker (in this case the server) accomplish? They
+            # could make us think that there's a newer version of the file
+            # out there, which would cause us to throw
+            # UncoordinatedWriteError (i.e. it's a DoS attack).
             share = (shnum, seqnum, root_hash)
             current_share_peers.add(shnum, (peerid, seqnum, root_hash) )
-            if not self._encprivkey:
-                self._encprivkey = enc_privkey
-                self._node._populate_encprivkey(self._encprivkey)
+
             if not self._privkey:
-                privkey_s = self._node._decrypt_privkey(enc_privkey)
-                self._privkey = rsa.create_signing_key_from_string(privkey_s)
-                self._node._populate_privkey(self._privkey)
-            # TODO: make sure we actually fill these in before we try to
-            # upload. This means we may need to re-fetch something if our
-            # initial read was too short.
+                self._try_to_extract_privkey(data, peerid, shnum)
+
+
+    def _try_to_extract_privkey(self, data, peerid, shnum):
+        try:
+            r = unpack_share(data)
+        except NeedMoreDataError, e:
+            # this share won't help us. oh well.
+            offset = e.encprivkey_offset
+            length = e.encprivkey_length
+            self.log("shnum %d on peerid %s: share was too short "
+                     "to get the encprivkey, but [%d:%d] ought to hold it" %
+                     (shnum, idlib.shortnodeid_b2a(peerid),
+                      offset, offset+length))
+
+            self._encprivkey_shares.append( (peerid, shnum, offset, length) )
+            return
+
+        (seqnum, root_hash, IV, k, N, segsize, datalen,
+         pubkey, signature, share_hash_chain, block_hash_tree,
+         share_data, enc_privkey) = r
+
+        return self._try_to_validate_privkey(enc_privkey, peerid, shnum)
+
+    def _try_to_validate_privkey(self, enc_privkey, peerid, shnum):
+        alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
+        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
+        if alleged_writekey != self._writekey:
+            self.log("WEIRD: invalid privkey from %s shnum %d" %
+                     (idlib.nodeid_b2a(peerid)[:8], shnum))
+            return
+
+        # it's good
+        self.log("got valid privkey from shnum %d on peerid %s" %
+                 (shnum, idlib.shortnodeid_b2a(peerid)))
+        self._privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
+        self._encprivkey = enc_privkey
+        self._node._populate_encprivkey(self._encprivkey)
+        self._node._populate_privkey(self._privkey)
 
     def _got_all_query_results(self, res,
                                total_shares, reachable_peers,
@@ -762,9 +825,43 @@ class Publish:
 
         assert not shares_needing_homes
 
-        target_info = (target_map, peer_storage_servers)
+        target_info = (target_map, shares_per_peer, peer_storage_servers)
         return target_info
 
+    def _obtain_privkey(self, target_info):
+        # make sure we've got a copy of our private key.
+        if self._privkey:
+            # Must have picked it up during _query_peers. We're good to go.
+            return target_info
+
+        # Nope, we haven't managed to grab a copy, and we still need it. Ask
+        # peers one at a time until we get a copy. Only bother asking peers
+        # who've admitted to holding a share.
+
+        target_map, shares_per_peer, peer_storage_servers = target_info
+        # pull shares from self._encprivkey_shares
+        if not self._encprivkey_shares:
+            raise NotEnoughPeersError("Unable to find a copy of the privkey")
+
+        (peerid, shnum, offset, length) = self._encprivkey_shares.pop(0)
+        self.log("trying to obtain privkey from %s shnum %d" %
+                 (idlib.shortnodeid_b2a(peerid), shnum))
+        d = self._do_privkey_query(peer_storage_servers[peerid], peerid,
+                                   shnum, offset, length)
+        d.addErrback(self.log_err)
+        d.addCallback(lambda res: self._obtain_privkey(target_info))
+        return d
+
+    def _do_privkey_query(self, rref, peerid, shnum, offset, length):
+        d = rref.callRemote("slot_readv", self._storage_index,
+                            [shnum], [(offset, length)] )
+        d.addCallback(self._privkey_query_response, peerid, shnum)
+        return d
+
+    def _privkey_query_response(self, datav, peerid, shnum):
+        data = datav[shnum][0]
+        self._try_to_validate_privkey(data, peerid, shnum)
+
     def _encrypt_and_encode(self, target_info,
                             newdata, readkey, IV,
                             required_shares, total_shares):
@@ -893,7 +990,7 @@ class Publish:
         # surprises here are *not* indications of UncoordinatedWriteError,
         # and we'll need to respond to them more gracefully.
 
-        target_map, peer_storage_servers = target_info
+        target_map, shares_per_peer, peer_storage_servers = target_info
 
         my_checkstring = pack_checkstring(seqnum, root_hash, IV)
         peer_messages = {}
@@ -950,7 +1047,7 @@ class Publish:
     def _got_write_answer(self, answer, tw_vectors, my_checkstring,
                           peerid, expected_old_shares,
                           dispatch_map):
-        self.log("_got_write_answer: %r" % (answer,))
+        self.log("_got_write_answer from %s" % idlib.shortnodeid_b2a(peerid))
         wrote, read_data = answer
         surprised = False
 
index 42a01a660da15b29a9fdeb48866eedf348c5cbad..f3a7abfc738b66fc845602967b16d6234051749b 100644 (file)
@@ -2,7 +2,7 @@
 import itertools, struct
 from twisted.trial import unittest
 from twisted.internet import defer
-from twisted.python import failure
+from twisted.python import failure, log
 from allmydata import mutable, uri, dirnode2
 from allmydata.dirnode2 import split_netstring
 from allmydata.util.hashutil import netstring, tagged_hash
@@ -91,6 +91,8 @@ class MyClient:
         self._num_peers = num_peers
         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                          for i in range(self._num_peers)]
+    def log(self, msg):
+        log.msg(msg)
 
     def get_renewal_secret(self):
         return "I hereby permit you to renew my files"
@@ -167,11 +169,12 @@ class Publish(unittest.TestCase):
         CONTENTS = "some initial contents"
         fn.create(CONTENTS)
         p = mutable.Publish(fn)
-        d = defer.maybeDeferred(p._encrypt_and_encode,
+        target_info = None
+        d = defer.maybeDeferred(p._encrypt_and_encode, target_info,
                                 CONTENTS, "READKEY", "IV"*8, 3, 10)
         def _done( ((shares, share_ids),
                     required_shares, total_shares,
-                    segsize, data_length, IV) ):
+                    segsize, data_length, target_info2) ):
             self.failUnlessEqual(len(shares), 10)
             for sh in shares:
                 self.failUnless(isinstance(sh, str))
@@ -181,7 +184,7 @@ class Publish(unittest.TestCase):
             self.failUnlessEqual(total_shares, 10)
             self.failUnlessEqual(segsize, 21)
             self.failUnlessEqual(data_length, len(CONTENTS))
-            self.failUnlessEqual(len(IV), 16)
+            self.failUnlessIdentical(target_info, target_info2)
         d.addCallback(_done)
         return d
 
@@ -196,16 +199,19 @@ class Publish(unittest.TestCase):
         r = mutable.Retrieve(fn)
         # make some fake shares
         shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
+        target_info = None
+        p._privkey = FakePrivKey(0)
+        p._encprivkey = "encprivkey"
+        p._pubkey = FakePubKey(0)
         d = defer.maybeDeferred(p._generate_shares,
                                 (shares_and_ids,
                                  3, 10,
                                  21, # segsize
                                  len(CONTENTS),
-                                 "IV"*8),
+                                 target_info),
                                 3, # seqnum
-                                FakePrivKey(0), "encprivkey", FakePubKey(0),
-                                )
-        def _done( (seqnum, root_hash, final_shares) ):
+                                "IV"*8)
+        def _done( (seqnum, root_hash, final_shares, target_info2) ):
             self.failUnlessEqual(seqnum, 3)
             self.failUnlessEqual(len(root_hash), 32)
             self.failUnless(isinstance(final_shares, dict))
@@ -243,6 +249,7 @@ class Publish(unittest.TestCase):
                 self.failUnlessEqual(IV, "IV"*8)
                 self.failUnlessEqual(len(share_data), len("%07d" % 1))
                 self.failUnlessEqual(enc_privkey, "encprivkey")
+            self.failUnlessIdentical(target_info, target_info2)
         d.addCallback(_done)
         return d
 
@@ -254,6 +261,7 @@ class Publish(unittest.TestCase):
         CONTENTS = "some initial contents"
         fn.create(CONTENTS)
         p = FakePublish(fn)
+        p._storage_index = "\x00"*32
         #r = mutable.Retrieve(fn)
         p._peers = {}
         for peerid in c._peerids:
@@ -279,13 +287,10 @@ class Publish(unittest.TestCase):
     def test_sharemap_20newpeers(self):
         c, p = self.setup_for_sharemap(20)
 
-        new_seqnum = 3
-        new_root_hash = "Rnew"
-        new_shares = None
         total_shares = 10
-        d = p._query_peers( (new_seqnum, new_root_hash, new_seqnum),
-                            total_shares)
-        def _done( (target_map, peer_storage_servers) ):
+        d = p._query_peers(total_shares)
+        def _done(target_info):
+            (target_map, shares_per_peer, peer_storage_servers) = target_info
             shares_per_peer = {}
             for shnum in target_map:
                 for (peerid, old_seqnum, old_R) in target_map[shnum]:
@@ -304,13 +309,10 @@ class Publish(unittest.TestCase):
     def test_sharemap_3newpeers(self):
         c, p = self.setup_for_sharemap(3)
 
-        new_seqnum = 3
-        new_root_hash = "Rnew"
-        new_shares = None
         total_shares = 10
-        d = p._query_peers( (new_seqnum, new_root_hash, new_seqnum),
-                            total_shares)
-        def _done( (target_map, peer_storage_servers) ):
+        d = p._query_peers(total_shares)
+        def _done(target_info):
+            (target_map, shares_per_peer, peer_storage_servers) = target_info
             shares_per_peer = {}
             for shnum in target_map:
                 for (peerid, old_seqnum, old_R) in target_map[shnum]:
@@ -327,37 +329,30 @@ class Publish(unittest.TestCase):
     def test_sharemap_nopeers(self):
         c, p = self.setup_for_sharemap(0)
 
-        new_seqnum = 3
-        new_root_hash = "Rnew"
-        new_shares = None
         total_shares = 10
         d = self.shouldFail(NotEnoughPeersError, "test_sharemap_nopeers",
-                            p._query_peers,
-                            (new_seqnum, new_root_hash, new_seqnum),
-                            total_shares)
+                            p._query_peers, total_shares)
         return d
 
-    def setup_for_write(self, num_peers, total_shares):
-        c, p = self.setup_for_sharemap(num_peers)
+    def test_write(self):
+        total_shares = 10
+        c, p = self.setup_for_sharemap(20)
+        p._privkey = FakePrivKey(0)
+        p._encprivkey = "encprivkey"
+        p._pubkey = FakePubKey(0)
         # make some fake shares
         CONTENTS = "some initial contents"
         shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
-        d = defer.maybeDeferred(p._generate_shares,
-                                (shares_and_ids,
-                                 3, total_shares,
-                                 21, # segsize
-                                 len(CONTENTS),
-                                 "IV"*8),
-                                3, # seqnum
-                                FakePrivKey(0), "encprivkey", FakePubKey(0),
-                                )
-        return d, p
-
-    def test_write(self):
-        total_shares = 10
-        d, p = self.setup_for_write(20, total_shares)
-        d.addCallback(p._query_peers, total_shares)
+        d = defer.maybeDeferred(p._query_peers, total_shares)
         IV = "IV"*8
+        d.addCallback(lambda target_info:
+                      p._generate_shares( (shares_and_ids,
+                                           3, total_shares,
+                                           21, # segsize
+                                           len(CONTENTS),
+                                           target_info),
+                                          3, # seqnum
+                                          IV))
         d.addCallback(p._send_shares, IV)
         def _done((surprised, dispatch_map)):
             self.failIf(surprised, "surprised!")
@@ -387,7 +382,6 @@ class Publish(unittest.TestCase):
         d.addCallback(_done)
         return d
 
-del Publish # gotta run, will fix this in a few hours
 
 class FakePubKey:
     def __init__(self, count):
index fa990b7a3d3ab7bcb7fb10adf020f415bcea8951..a972b22e751822239a2d70a4a048826ce0e72e5c 100644 (file)
@@ -253,3 +253,6 @@ from foolscap import base32
 def nodeid_b2a(nodeid):
     # we display nodeids using the same base32 alphabet that Foolscap uses
     return base32.encode(nodeid)
+
+def shortnodeid_b2a(nodeid):
+    return nodeid_b2a(nodeid)[:8]