From: Brian Warner Date: Thu, 7 Jun 2007 07:15:41 +0000 (-0700) Subject: fetch plaintext/crypttext merkle trees during download, but don't check the segments... X-Git-Tag: allmydata-tahoe-0.3.0~24 X-Git-Url: https://git.rkrishnan.org/%5B/%5D%20/uri/index.php?a=commitdiff_plain;h=e04ff3adac47d8e522ef2f694d3cf913b2329d9c;p=tahoe-lafs%2Ftahoe-lafs.git fetch plaintext/crypttext merkle trees during download, but don't check the segments against them yet --- diff --git a/src/allmydata/download.py b/src/allmydata/download.py index 3f5d0d96..db0c1432 100644 --- a/src/allmydata/download.py +++ b/src/allmydata/download.py @@ -21,6 +21,10 @@ class HaveAllPeersError(Exception): class BadThingAHashValue(Exception): pass +class BadPlaintextHashValue(Exception): + pass +class BadCrypttextHashValue(Exception): + pass class Output: def __init__(self, downloadable, key): @@ -30,6 +34,12 @@ class Output: self._verifierid_hasher = sha.new(netstring("allmydata_verifierid_v1")) self._fileid_hasher = sha.new(netstring("allmydata_fileid_v1")) self.length = 0 + self._plaintext_hash_tree = None + self._crypttext_hash_tree = None + + def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree): + self._plaintext_hash_tree = plaintext_hashtree + self._crypttext_hash_tree = crypttext_hashtree def open(self): self.downloadable.open() @@ -251,6 +261,7 @@ class FileDownloader: # now get the thingA block from somebody and validate it d.addCallback(self._obtain_thingA) d.addCallback(self._got_thingA) + d.addCallback(self._get_hashtrees) d.addCallback(self._create_validated_buckets) # once we know that, we can download blocks from everybody d.addCallback(self._download_all_segments) @@ -357,6 +368,50 @@ class FileDownloader: self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares']) self._share_hashtree.set_hashes({0: self._roothash}) + def _get_hashtrees(self, res): + d = self._get_plaintext_hashtrees() + d.addCallback(self._get_crypttext_hashtrees) + d.addCallback(self._setup_hashtrees) + return d + + def _get_plaintext_hashtrees(self): + def _validate_plaintext_hashtree(proposal, bucket): + if proposal[0] != self._thingA_data['plaintext_root_hash']: + msg = ("The copy of the plaintext_root_hash we received from" + " %s was bad" % bucket) + raise BadPlaintextHashValue(msg) + self._plaintext_hashes = proposal + d = self._obtain_validated_thing(None, + self._thingA_sources, + "plaintext_hashes", + "get_plaintext_hashes", (), + _validate_plaintext_hashtree) + return d + + def _get_crypttext_hashtrees(self, res): + def _validate_crypttext_hashtree(proposal, bucket): + if proposal[0] != self._thingA_data['crypttext_root_hash']: + msg = ("The copy of the crypttext_root_hash we received from" + " %s was bad" % bucket) + raise BadCrypttextHashValue(msg) + self._crypttext_hashes = proposal + d = self._obtain_validated_thing(None, + self._thingA_sources, + "crypttext_hashes", + "get_crypttext_hashes", (), + _validate_crypttext_hashtree) + return d + + def _setup_hashtrees(self, res): + plaintext_hashtree = hashtree.IncompleteHashTree(self._total_segments) + plaintext_hashes = dict(list(enumerate(self._plaintext_hashes))) + plaintext_hashtree.set_hashes(plaintext_hashes) + crypttext_hashtree = hashtree.IncompleteHashTree(self._total_segments) + crypttext_hashes = dict(list(enumerate(self._crypttext_hashes))) + crypttext_hashtree.set_hashes(crypttext_hashes) + self._output.setup_hashtrees(plaintext_hashtree, crypttext_hashtree) + + def _create_validated_buckets(self, ignored=None): self._share_vbuckets = {} for sharenum, bucket in self._share_buckets: diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py index a5200459..444d1e45 100644 --- a/src/allmydata/encode.py +++ b/src/allmydata/encode.py @@ -6,11 +6,15 @@ from twisted.python import log from allmydata.hashtree import HashTree, \ block_hash, thingA_hash, plaintext_hash, crypttext_hash from allmydata.Crypto.Cipher import AES +from allmydata.Crypto.Hash import SHA256 from allmydata.util import mathutil, bencode from allmydata.util.assertutil import _assert from allmydata.codec import CRSEncoder from allmydata.interfaces import IEncoder +def netstring(s): + return "%d:%s," % (len(s), s) + """ The goal of the encoder is to turn the original file into a series of @@ -219,14 +223,20 @@ class Encoder(object): # of additional shares which can be substituted if the primary ones # are unavailable + plaintext_hasher = SHA256.new(netstring("allmydata_plaintext_segment_v1")) + crypttext_hasher = SHA256.new(netstring("allmydata_crypttext_segment_v1")) + for i in range(self.required_shares): input_piece = self.infile.read(input_piece_size) # non-tail segments should be the full segment size assert len(input_piece) == input_piece_size - self._plaintext_hashes.append(plaintext_hash(input_piece)) + plaintext_hasher.update(input_piece) encrypted_piece = self.cryptor.encrypt(input_piece) + crypttext_hasher.update(encrypted_piece) chunks.append(encrypted_piece) - self._crypttext_hashes.append(crypttext_hash(encrypted_piece)) + + self._plaintext_hashes.append(plaintext_hasher.digest()) + self._crypttext_hashes.append(crypttext_hasher.digest()) d = codec.encode(chunks) d.addCallback(self._encoded_segment, segnum) return d @@ -236,15 +246,21 @@ class Encoder(object): codec = self._tail_codec input_piece_size = codec.get_block_size() + plaintext_hasher = SHA256.new(netstring("allmydata_plaintext_segment_v1")) + crypttext_hasher = SHA256.new(netstring("allmydata_crypttext_segment_v1")) + for i in range(self.required_shares): input_piece = self.infile.read(input_piece_size) - self._plaintext_hashes.append(plaintext_hash(input_piece)) + plaintext_hasher.update(input_piece) if len(input_piece) < input_piece_size: # padding input_piece += ('\x00' * (input_piece_size - len(input_piece))) encrypted_piece = self.cryptor.encrypt(input_piece) - self._crypttext_hashes.append(crypttext_hash(encrypted_piece)) + crypttext_hasher.update(encrypted_piece) chunks.append(encrypted_piece) + + self._plaintext_hashes.append(plaintext_hash(input_piece)) + self._crypttext_hashes.append(crypttext_hash(encrypted_piece)) d = codec.encode(chunks) d.addCallback(self._encoded_segment, segnum) return d @@ -290,7 +306,7 @@ class Encoder(object): # even more UNUSUAL log.msg(" weird, they weren't in our list of landlords") if len(self.landlords) < self.shares_of_happiness: - msg = "lost too many shareholders during upload" + msg = "lost too many shareholders during upload: %s" % why raise NotEnoughPeersError(msg) log.msg("but we can still continue with %s shares, we'll be happy " "with at least %s" % (len(self.landlords), diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py index 2917ee46..ce2cbe84 100644 --- a/src/allmydata/interfaces.py +++ b/src/allmydata/interfaces.py @@ -87,6 +87,12 @@ class RIBucketReader(RemoteInterface): than the others. """ return ShareData + + def get_plaintext_hashes(): + return ListOf(Hash, maxLength=2**20) + def get_crypttext_hashes(): + return ListOf(Hash, maxLength=2**20) + def get_block_hashes(): return ListOf(Hash, maxLength=2**20) def get_share_hashes(): diff --git a/src/allmydata/storageserver.py b/src/allmydata/storageserver.py index 788c6389..79e7a06f 100644 --- a/src/allmydata/storageserver.py +++ b/src/allmydata/storageserver.py @@ -107,6 +107,11 @@ class BucketReader(Referenceable): f.seek(self.blocksize * blocknum) return f.read(self.blocksize) # this might be short for the last block + def remote_get_plaintext_hashes(self): + return str2l(self._read_file('plaintext_hashes')) + def remote_get_crypttext_hashes(self): + return str2l(self._read_file('crypttext_hashes')) + def remote_get_block_hashes(self): return str2l(self._read_file('blockhashes')) diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index cb6f2e37..3bbfa847 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -105,6 +105,19 @@ class FakeBucketWriter: return self.flip_bit(self.blocks[blocknum]) return self.blocks[blocknum] + def get_plaintext_hashes(self): + if self.mode == "bad plaintexthash": + hashes = self.plaintext_hashes[:] + hashes[1] = self.flip_bit(hashes[1]) + return hashes + return self.plaintext_hashes + def get_crypttext_hashes(self): + if self.mode == "bad crypttexthash": + hashes = self.crypttext_hashes[:] + hashes[1] = self.flip_bit(hashes[1]) + return hashes + return self.crypttext_hashes + def get_block_hashes(self): if self.mode == "bad blockhash": hashes = self.block_hashes[:]