fetch plaintext/crypttext merkle trees during download, but don't check the segments...
authorBrian Warner <warner@lothar.com>
Thu, 7 Jun 2007 07:15:41 +0000 (00:15 -0700)
committerBrian Warner <warner@lothar.com>
Thu, 7 Jun 2007 07:15:41 +0000 (00:15 -0700)
src/allmydata/download.py
src/allmydata/encode.py
src/allmydata/interfaces.py
src/allmydata/storageserver.py
src/allmydata/test/test_encode.py

index 3f5d0d962f3308565761417ce7a01bd02831ee57..db0c1432a9bec3b8917d9a06c9719b00a0911bc4 100644 (file)
@@ -21,6 +21,10 @@ class HaveAllPeersError(Exception):
 
 class BadThingAHashValue(Exception):
     pass
+class BadPlaintextHashValue(Exception):
+    pass
+class BadCrypttextHashValue(Exception):
+    pass
 
 class Output:
     def __init__(self, downloadable, key):
@@ -30,6 +34,12 @@ class Output:
         self._verifierid_hasher = sha.new(netstring("allmydata_verifierid_v1"))
         self._fileid_hasher = sha.new(netstring("allmydata_fileid_v1"))
         self.length = 0
+        self._plaintext_hash_tree = None
+        self._crypttext_hash_tree = None
+
+    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
+        self._plaintext_hash_tree = plaintext_hashtree
+        self._crypttext_hash_tree = crypttext_hashtree
 
     def open(self):
         self.downloadable.open()
@@ -251,6 +261,7 @@ class FileDownloader:
         # now get the thingA block from somebody and validate it
         d.addCallback(self._obtain_thingA)
         d.addCallback(self._got_thingA)
+        d.addCallback(self._get_hashtrees)
         d.addCallback(self._create_validated_buckets)
         # once we know that, we can download blocks from everybody
         d.addCallback(self._download_all_segments)
@@ -357,6 +368,50 @@ class FileDownloader:
         self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
         self._share_hashtree.set_hashes({0: self._roothash})
 
+    def _get_hashtrees(self, res):
+        d = self._get_plaintext_hashtrees()
+        d.addCallback(self._get_crypttext_hashtrees)
+        d.addCallback(self._setup_hashtrees)
+        return d
+
+    def _get_plaintext_hashtrees(self):
+        def _validate_plaintext_hashtree(proposal, bucket):
+            if proposal[0] != self._thingA_data['plaintext_root_hash']:
+                msg = ("The copy of the plaintext_root_hash we received from"
+                       " %s was bad" % bucket)
+                raise BadPlaintextHashValue(msg)
+            self._plaintext_hashes = proposal
+        d = self._obtain_validated_thing(None,
+                                         self._thingA_sources,
+                                         "plaintext_hashes",
+                                         "get_plaintext_hashes", (),
+                                         _validate_plaintext_hashtree)
+        return d
+
+    def _get_crypttext_hashtrees(self, res):
+        def _validate_crypttext_hashtree(proposal, bucket):
+            if proposal[0] != self._thingA_data['crypttext_root_hash']:
+                msg = ("The copy of the crypttext_root_hash we received from"
+                       " %s was bad" % bucket)
+                raise BadCrypttextHashValue(msg)
+            self._crypttext_hashes = proposal
+        d = self._obtain_validated_thing(None,
+                                         self._thingA_sources,
+                                         "crypttext_hashes",
+                                         "get_crypttext_hashes", (),
+                                         _validate_crypttext_hashtree)
+        return d
+
+    def _setup_hashtrees(self, res):
+        plaintext_hashtree = hashtree.IncompleteHashTree(self._total_segments)
+        plaintext_hashes = dict(list(enumerate(self._plaintext_hashes)))
+        plaintext_hashtree.set_hashes(plaintext_hashes)
+        crypttext_hashtree = hashtree.IncompleteHashTree(self._total_segments)
+        crypttext_hashes = dict(list(enumerate(self._crypttext_hashes)))
+        crypttext_hashtree.set_hashes(crypttext_hashes)
+        self._output.setup_hashtrees(plaintext_hashtree, crypttext_hashtree)
+
+
     def _create_validated_buckets(self, ignored=None):
         self._share_vbuckets = {}
         for sharenum, bucket in self._share_buckets:
index a5200459854705f0bfec9d64eed947f94d6bc9a0..444d1e45e220c99c32f0985dbaab35c0f1e1fa0d 100644 (file)
@@ -6,11 +6,15 @@ from twisted.python import log
 from allmydata.hashtree import HashTree, \
      block_hash, thingA_hash, plaintext_hash, crypttext_hash
 from allmydata.Crypto.Cipher import AES
+from allmydata.Crypto.Hash import SHA256
 from allmydata.util import mathutil, bencode
 from allmydata.util.assertutil import _assert
 from allmydata.codec import CRSEncoder
 from allmydata.interfaces import IEncoder
 
+def netstring(s):
+    return "%d:%s," % (len(s), s)
+
 """
 
 The goal of the encoder is to turn the original file into a series of
@@ -219,14 +223,20 @@ class Encoder(object):
         # of additional shares which can be substituted if the primary ones
         # are unavailable
 
+        plaintext_hasher = SHA256.new(netstring("allmydata_plaintext_segment_v1"))
+        crypttext_hasher = SHA256.new(netstring("allmydata_crypttext_segment_v1"))
+
         for i in range(self.required_shares):
             input_piece = self.infile.read(input_piece_size)
             # non-tail segments should be the full segment size
             assert len(input_piece) == input_piece_size
-            self._plaintext_hashes.append(plaintext_hash(input_piece))
+            plaintext_hasher.update(input_piece)
             encrypted_piece = self.cryptor.encrypt(input_piece)
+            crypttext_hasher.update(encrypted_piece)
             chunks.append(encrypted_piece)
-            self._crypttext_hashes.append(crypttext_hash(encrypted_piece))
+
+        self._plaintext_hashes.append(plaintext_hasher.digest())
+        self._crypttext_hashes.append(crypttext_hasher.digest())
         d = codec.encode(chunks)
         d.addCallback(self._encoded_segment, segnum)
         return d
@@ -236,15 +246,21 @@ class Encoder(object):
         codec = self._tail_codec
         input_piece_size = codec.get_block_size()
 
+        plaintext_hasher = SHA256.new(netstring("allmydata_plaintext_segment_v1"))
+        crypttext_hasher = SHA256.new(netstring("allmydata_crypttext_segment_v1"))
+
         for i in range(self.required_shares):
             input_piece = self.infile.read(input_piece_size)
-            self._plaintext_hashes.append(plaintext_hash(input_piece))
+            plaintext_hasher.update(input_piece)
             if len(input_piece) < input_piece_size:
                 # padding
                 input_piece += ('\x00' * (input_piece_size - len(input_piece)))
             encrypted_piece = self.cryptor.encrypt(input_piece)
-            self._crypttext_hashes.append(crypttext_hash(encrypted_piece))
+            crypttext_hasher.update(encrypted_piece)
             chunks.append(encrypted_piece)
+
+        self._plaintext_hashes.append(plaintext_hash(input_piece))
+        self._crypttext_hashes.append(crypttext_hash(encrypted_piece))
         d = codec.encode(chunks)
         d.addCallback(self._encoded_segment, segnum)
         return d
@@ -290,7 +306,7 @@ class Encoder(object):
             # even more UNUSUAL
             log.msg(" weird, they weren't in our list of landlords")
         if len(self.landlords) < self.shares_of_happiness:
-            msg = "lost too many shareholders during upload"
+            msg = "lost too many shareholders during upload: %s" % why
             raise NotEnoughPeersError(msg)
         log.msg("but we can still continue with %s shares, we'll be happy "
                 "with at least %s" % (len(self.landlords),
index 2917ee460ce0e743222528547f67470886dcd763..ce2cbe841e3e9ffc2b2459dd1d129bc2b323eff8 100644 (file)
@@ -87,6 +87,12 @@ class RIBucketReader(RemoteInterface):
         than the others.
         """
         return ShareData
+
+    def get_plaintext_hashes():
+        return ListOf(Hash, maxLength=2**20)
+    def get_crypttext_hashes():
+        return ListOf(Hash, maxLength=2**20)
+
     def get_block_hashes():
         return ListOf(Hash, maxLength=2**20)
     def get_share_hashes():
index 788c638968ef094e9ffc03d5cf1f9acd234c756f..79e7a06f5dbb0a047f2a5c4938d2494fe584e6fe 100644 (file)
@@ -107,6 +107,11 @@ class BucketReader(Referenceable):
         f.seek(self.blocksize * blocknum)
         return f.read(self.blocksize) # this might be short for the last block
 
+    def remote_get_plaintext_hashes(self):
+        return str2l(self._read_file('plaintext_hashes'))
+    def remote_get_crypttext_hashes(self):
+        return str2l(self._read_file('crypttext_hashes'))
+
     def remote_get_block_hashes(self):
         return str2l(self._read_file('blockhashes'))
 
index cb6f2e37cd386dbfc8cf9e77599e4a413e60470f..3bbfa84714d3babd65e5d2cede59526c1e675241 100644 (file)
@@ -105,6 +105,19 @@ class FakeBucketWriter:
             return self.flip_bit(self.blocks[blocknum])
         return self.blocks[blocknum]
 
+    def get_plaintext_hashes(self):
+        if self.mode == "bad plaintexthash":
+            hashes = self.plaintext_hashes[:]
+            hashes[1] = self.flip_bit(hashes[1])
+            return hashes
+        return self.plaintext_hashes
+    def get_crypttext_hashes(self):
+        if self.mode == "bad crypttexthash":
+            hashes = self.crypttext_hashes[:]
+            hashes[1] = self.flip_bit(hashes[1])
+            return hashes
+        return self.crypttext_hashes
+
     def get_block_hashes(self):
         if self.mode == "bad blockhash":
             hashes = self.block_hashes[:]