From: Brian Warner Date: Tue, 24 Jul 2007 02:31:53 +0000 (-0700) Subject: refactor upload/encode, to split encrypt and encode responsibilities X-Git-Url: https://git.rkrishnan.org/?a=commitdiff_plain;h=e6e9ddc5883b924bfac8c7f48038668eb2bc4479;p=tahoe-lafs%2Ftahoe-lafs.git refactor upload/encode, to split encrypt and encode responsibilities --- diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py index 3290aec7..e1de2eae 100644 --- a/src/allmydata/encode.py +++ b/src/allmydata/encode.py @@ -5,11 +5,11 @@ from twisted.internet import defer from twisted.python import log from allmydata import uri from allmydata.hashtree import HashTree -from allmydata.Crypto.Cipher import AES from allmydata.util import mathutil, hashutil from allmydata.util.assertutil import _assert from allmydata.codec import CRSEncoder -from allmydata.interfaces import IEncoder, IStorageBucketWriter +from allmydata.interfaces import IEncoder, IStorageBucketWriter, \ + IEncryptedUploadable """ @@ -87,21 +87,20 @@ class Encoder(object): self.SHARES_OF_HAPPINESS = happy self.TOTAL_SHARES = n self.uri_extension_data = {} + self._codec = None def set_size(self, size): + assert not self._codec self.file_size = size def set_params(self, encoding_parameters): + assert not self._codec k,d,n = encoding_parameters self.NEEDED_SHARES = k self.SHARES_OF_HAPPINESS = d self.TOTAL_SHARES = n - def set_uploadable(self, uploadable): - self._uploadable = uploadable - - def setup(self): - + def _setup_codec(self): self.num_shares = self.TOTAL_SHARES self.required_shares = self.NEEDED_SHARES self.shares_of_happiness = self.SHARES_OF_HAPPINESS @@ -110,9 +109,9 @@ class Encoder(object): # this must be a multiple of self.required_shares self.segment_size = mathutil.next_multiple(self.segment_size, self.required_shares) - self._setup_codec() - def _setup_codec(self): + # now set up the codec + assert self.segment_size % self.required_shares == 0 self.num_segments = mathutil.div_ceil(self.file_size, 
self.segment_size) @@ -150,24 +149,48 @@ class Encoder(object): self.required_shares, self.num_shares) data['tail_codec_params'] = self._tail_codec.get_serialized_params() - def get_serialized_params(self): - return self._codec.get_serialized_params() - - def set_encryption_key(self, key): - assert isinstance(key, str) - assert len(key) == 16 # AES-128 - self.key = key - - def get_share_size(self): + def _get_share_size(self): share_size = mathutil.div_ceil(self.file_size, self.required_shares) - overhead = self.compute_overhead() + overhead = self._compute_overhead() return share_size + overhead - def compute_overhead(self): + + def _compute_overhead(self): return 0 - def get_block_size(self): - return self._codec.get_block_size() - def get_num_segments(self): - return self.num_segments + + def set_encrypted_uploadable(self, uploadable): + u = self._uploadable = IEncryptedUploadable(uploadable) + d = u.get_size() + d.addCallback(self.set_size) + d.addCallback(lambda res: self.get_param("serialized_params")) + d.addCallback(u.set_serialized_encoding_parameters) + d.addCallback(lambda res: u.get_storage_index()) + def _done(storage_index): + self._storage_index = storage_index + return self + d.addCallback(_done) + return d + + def get_param(self, name): + if not self._codec: + self._setup_codec() + + if name == "storage_index": + return self._storage_index + elif name == "share_counts": + return (self.required_shares, self.shares_of_happiness, + self.num_shares) + elif name == "num_segments": + return self.num_segments + elif name == "segment_size": + return self.segment_size + elif name == "block_size": + return self._codec.get_block_size() + elif name == "share_size": + return self._get_share_size() + elif name == "serialized_params": + return self._codec.get_serialized_params() + else: + raise KeyError("unknown parameter name '%s'" % name) def set_shareholders(self, landlords): assert isinstance(landlords, dict) @@ -177,12 +200,20 @@ class Encoder(object): def 
start(self): #paddedsize = self._size + mathutil.pad_size(self._size, self.needed_shares) - self._plaintext_hasher = hashutil.plaintext_hasher() - self._plaintext_hashes = [] + if not self._codec: + self._setup_codec() + self._crypttext_hasher = hashutil.crypttext_hasher() self._crypttext_hashes = [] - self.setup_encryption() - d = defer.succeed(None) + self.segment_num = 0 + self.subshare_hashes = [[] for x in range(self.num_shares)] + # subshare_hashes[i] is a list that will be accumulated and then send + # to landlord[i]. This list contains a hash of each segment_share + # that we sent to that landlord. + self.share_root_hashes = [None] * self.num_shares + + d = defer.maybeDeferred(self._uploadable.set_segment_size, + self.segment_size) for l in self.landlords.values(): d.addCallback(lambda res, l=l: l.start()) @@ -198,7 +229,7 @@ class Encoder(object): d.addCallback(lambda res: self._encode_tail_segment(last_segnum)) d.addCallback(self._send_segment, last_segnum) - d.addCallback(lambda res: self.finish_flat_hashes()) + d.addCallback(lambda res: self.finish_hashing()) d.addCallback(lambda res: self.send_plaintext_hash_tree_to_all_shareholders()) @@ -212,16 +243,6 @@ class Encoder(object): d.addCallbacks(lambda res: self.done(), self.err) return d - def setup_encryption(self): - self.cryptor = AES.new(key=self.key, mode=AES.MODE_CTR, - counterstart="\x00"*16) - self.segment_num = 0 - self.subshare_hashes = [[] for x in range(self.num_shares)] - # subshare_hashes[i] is a list that will be accumulated and then send - # to landlord[i]. This list contains a hash of each segment_share - # that we sent to that landlord. 
- self.share_root_hashes = [None] * self.num_shares - def _encode_segment(self, segnum): codec = self._codec @@ -241,7 +262,6 @@ class Encoder(object): # of additional shares which can be substituted if the primary ones # are unavailable - plaintext_segment_hasher = hashutil.plaintext_segment_hasher() crypttext_segment_hasher = hashutil.crypttext_segment_hasher() # memory footprint: we only hold a tiny piece of the plaintext at any @@ -252,12 +272,10 @@ class Encoder(object): # footprint to 500KiB at the expense of more hash-tree overhead. d = self._gather_data(self.required_shares, input_piece_size, - plaintext_segment_hasher, crypttext_segment_hasher) def _done(chunks): for c in chunks: assert len(c) == input_piece_size - self._plaintext_hashes.append(plaintext_segment_hasher.digest()) self._crypttext_hashes.append(crypttext_segment_hasher.digest()) # during this call, we hit 5*segsize memory return codec.encode(chunks) @@ -269,11 +287,9 @@ class Encoder(object): codec = self._tail_codec input_piece_size = codec.get_block_size() - plaintext_segment_hasher = hashutil.plaintext_segment_hasher() crypttext_segment_hasher = hashutil.crypttext_segment_hasher() d = self._gather_data(self.required_shares, input_piece_size, - plaintext_segment_hasher, crypttext_segment_hasher, allow_short=True) def _done(chunks): @@ -281,14 +297,13 @@ class Encoder(object): # a short trailing chunk will have been padded by # _gather_data assert len(c) == input_piece_size - self._plaintext_hashes.append(plaintext_segment_hasher.digest()) self._crypttext_hashes.append(crypttext_segment_hasher.digest()) return codec.encode(chunks) d.addCallback(_done) return d def _gather_data(self, num_chunks, input_chunk_size, - plaintext_segment_hasher, crypttext_segment_hasher, + crypttext_segment_hasher, allow_short=False, previous_chunks=[]): """Return a Deferred that will fire when the required number of @@ -299,19 +314,13 @@ class Encoder(object): if not num_chunks: return 
defer.succeed(previous_chunks) - d = self._uploadable.read(input_chunk_size) + d = self._uploadable.read_encrypted(input_chunk_size) def _got(data): encrypted_pieces = [] length = 0 - # we use data.pop(0) instead of 'for input_piece in data' to save - # memory: each piece is destroyed as soon as we're done with it. while data: - input_piece = data.pop(0) - length += len(input_piece) - plaintext_segment_hasher.update(input_piece) - self._plaintext_hasher.update(input_piece) - encrypted_piece = self.cryptor.encrypt(input_piece) - assert len(encrypted_piece) == len(input_piece) + encrypted_piece = data.pop(0) + length += len(encrypted_piece) crypttext_segment_hasher.update(encrypted_piece) self._crypttext_hasher.update(encrypted_piece) encrypted_pieces.append(encrypted_piece) @@ -331,7 +340,6 @@ class Encoder(object): d.addCallback(_got) d.addCallback(lambda chunks: self._gather_data(num_chunks-1, input_chunk_size, - plaintext_segment_hasher, crypttext_segment_hasher, allow_short, chunks)) return d @@ -403,20 +411,28 @@ class Encoder(object): d0.addErrback(_eatNotEnoughPeersError) return d - def finish_flat_hashes(self): - plaintext_hash = self._plaintext_hasher.digest() + def finish_hashing(self): crypttext_hash = self._crypttext_hasher.digest() - self.uri_extension_data["plaintext_hash"] = plaintext_hash self.uri_extension_data["crypttext_hash"] = crypttext_hash + u = self._uploadable + d = u.get_plaintext_hash() + def _got(plaintext_hash): + self.uri_extension_data["plaintext_hash"] = plaintext_hash + return u.get_plaintext_segment_hashtree_nodes(self.num_segments) + d.addCallback(_got) + def _got_hashtree_nodes(t): + self.uri_extension_data["plaintext_root_hash"] = t[0] + self._plaintext_hashtree_nodes = t + d.addCallback(_got_hashtree_nodes) + return d def send_plaintext_hash_tree_to_all_shareholders(self): log.msg("%s sending plaintext hash tree" % self) - t = HashTree(self._plaintext_hashes) - all_hashes = list(t) - 
self.uri_extension_data["plaintext_root_hash"] = t[0] dl = [] for shareid in self.landlords.keys(): - dl.append(self.send_plaintext_hash_tree(shareid, all_hashes)) + d = self.send_plaintext_hash_tree(shareid, + self._plaintext_hashtree_nodes) + dl.append(d) return self._gather_responses(dl) def send_plaintext_hash_tree(self, shareid, all_hashes): @@ -528,7 +544,8 @@ class Encoder(object): def done(self): log.msg("%s: upload done" % self) - return self.uri_extension_hash + return (self.uri_extension_hash, self.required_shares, + self.num_shares, self.file_size) def err(self, f): log.msg("%s: upload failed: %s" % (self, f)) # UNUSUAL diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py index cffc7649..5e0dce64 100644 --- a/src/allmydata/interfaces.py +++ b/src/allmydata/interfaces.py @@ -588,60 +588,110 @@ class ICodecDecoder(Interface): """ class IEncoder(Interface): - """I take a file-like object that provides a sequence of bytes and a list - of shareholders, then encrypt, encode, hash, and deliver shares to those - shareholders. I will compute all the necessary Merkle hash trees that are - necessary to validate the data that eventually comes back from the - shareholders. I provide the root hash of the hash tree, and the encoding - parameters, both of which must be included in the URI. + """I take an object that provides IEncryptedUploadable, which provides + encrypted data, and a list of shareholders. I then encode, hash, and + deliver shares to those shareholders. I will compute all the necessary + Merkle hash trees that are necessary to validate the crypttext that + eventually comes back from the shareholders. I provide the URI Extension + Block Hash, and the encoding parameters, both of which must be included + in the URI. I do not choose shareholders, that is left to the IUploader. I must be given a dict of RemoteReferences to storage buckets that are ready and willing to receive data. 
""" - def setup(infile): - """I take a file-like object (providing seek, tell, and read) from - which all the plaintext data that is to be uploaded can be read. I - will seek to the beginning of the file before reading any data. - setup() must be called before making any other calls, in particular - before calling get_reservation_size(). + def set_size(size): + """Specify the number of bytes that will be encoded. This must be + peformed before get_serialized_params() can be called. """ + def set_params(params): + """Override the default encoding parameters. 'params' is a tuple of + (k,d,n), where 'k' is the number of required shares, 'd' is the + shares_of_happiness, and 'n' is the total number of shares that will + be created. - def get_share_size(): - """I return the size of the data that will be stored on each - shareholder. This is aggregate amount of data that will be sent to - the shareholder, summed over all the put_block() calls I will ever - make. + Encoding parameters can be set in three ways. 1: The Encoder class + provides defaults (25/75/100). 2: the Encoder can be constructed with + an 'options' dictionary, in which the + needed_and_happy_and_total_shares' key can be a (k,d,n) tuple. 3: + set_params((k,d,n)) can be called. - TODO: this might also include some amount of overhead, like the size - of all the hashes. We need to decide whether this is useful or not. + If you intend to use set_params(), you must call it before + get_share_size or get_param are called. + """ + + def set_encrypted_uploadable(u): + """Provide a source of encrypted upload data. 'u' must implement + IEncryptedUploadable. + + When this is called, the IEncryptedUploadable will be queried for its + length and the storage_index that should be used. - It is useful to determine this size before asking potential - shareholders whether they will grant a lease or not, since their - answers will depend upon how much space we need. 
+ This returns a Deferred that fires with this Encoder instance. + + This must be performed before start() can be called. """ - def get_block_size(): # TODO: can we avoid exposing this? - """I return the size of the individual blocks that will be delivered - to a shareholder's put_block() method. By knowing this, the - shareholder will be able to keep all blocks in a single file and - still provide random access when reading them. + def get_param(name): + """Return an encoding parameter, by name. + + 'storage_index': return a string with the (16-byte truncated SHA-256 + hash) storage index to which these shares should be + pushed. + + 'share_counts': return a tuple describing how many shares are used: + (needed_shares, shares_of_happiness, total_shares) + + 'num_segments': return an int with the number of segments that + will be encoded. + + 'segment_size': return an int with the size of each segment. + + 'block_size': return the size of the individual blocks that will + be delivered to a shareholder's put_block() method. By + knowing this, the shareholder will be able to keep all + blocks in a single file and still provide random access + when reading them. # TODO: can we avoid exposing this? + + 'share_size': an int with the size of the data that will be stored + on each shareholder. This is aggregate amount of data + that will be sent to the shareholder, summed over all + the put_block() calls I will ever make. It is useful to + determine this size before asking potential + shareholders whether they will grant a lease or not, + since their answers will depend upon how much space we + need. TODO: this might also include some amount of + overhead, like the size of all the hashes. We need to + decide whether this is useful or not. + + 'serialized_params': a string with a concise description of the + codec name and its parameters. 
This may be passed + into the IUploadable to let it make sure that + the same file encoded with different parameters + will result in different storage indexes. + + Once this is called, set_size() and set_params() may not be called. """ def set_shareholders(shareholders): - """I take a dictionary that maps share identifiers (small integers, - starting at 0) to RemoteReferences that provide RIBucketWriter. This - must be called before start(). - """ + """Tell the encoder where to put the encoded shares. 'shareholders' + must be a dictionary that maps share number (an integer ranging from + 0 to n-1) to an instance that provides IStorageBucketWriter. This + must be performed before start() can be called.""" def start(): - """I start the upload. This process involves reading data from the - input file, encrypting it, encoding the pieces, uploading the shares + """Begin the encode/upload process. This involves reading encrypted + data from the IEncryptedUploadable, encoding it, uploading the shares to the shareholders, then sending the hash trees. - I return a Deferred that fires with the hash of the uri_extension - data block. + set_encrypted_uploadable() and set_shareholders() must be called + before this can be invoked. + + This returns a Deferred that fires with a tuple of + (uri_extension_hash, needed_shares, total_shares, size) when the + upload process is complete. This information, plus the encryption + key, is sufficient to construct the URI. """ class IDecoder(Interface): @@ -712,6 +762,62 @@ class IDownloader(Interface): Returns a Deferred that fires (with the results of target.finish) when the download is finished, or errbacks if something went wrong.""" +class IEncryptedUploadable(Interface): + def get_size(): + """This behaves just like IUploadable.get_size().""" + + def set_serialized_encoding_parameters(serialized_encoding_parameters): + """Tell me what encoding parameters will be used for my data. 
+ + 'serialized_encoding_parameters' is a string which indicates how the + data will be encoded (codec name, blocksize, number of shares). + + I may use this when get_storage_index() is called, to influence the + index that I return. Or, I may just ignore it. + + set_serialized_encoding_parameters() may be called 0 or 1 times. If + called, it must be called before get_storage_index(). + """ + + def get_storage_index(): + """Return a Deferred that fires with a 16-byte storage index. This + value may be influenced by the parameters earlier set by + set_serialized_encoding_parameters(). + """ + + def set_segment_size(segment_size): + """Set the segment size, to allow the IEncryptedUploadable to + accurately create the plaintext segment hash tree. This must be + called before any calls to read_encrypted.""" + + def read_encrypted(length): + """This behaves just like IUploadable.read(), but returns crypttext + instead of plaintext. set_segment_size() must be called before the + first call to read_encrypted().""" + + def get_plaintext_segment_hashtree_nodes(num_segments): + """Get the nodes of a merkle hash tree over the plaintext segments. + + This returns a Deferred which fires with a sequence of hashes. Each + hash is a node of a merkle hash tree, generally obtained from:: + + tuple(HashTree(segment_hashes)) + + 'num_segments' is used to assert that the number of segments that the + IEncryptedUploadable handled matches the number of segments that the + encoder was expecting. + """ + + def get_plaintext_hash(): + """Get the hash of the whole plaintext. + + This returns a Deferred which fires with a tagged SHA-256 hash of the + whole plaintext, obtained from hashutil.plaintext_hash(data). + """ + + def close(): + """Just like IUploadable.close().""" + class IUploadable(Interface): def get_size(): """Return a Deferred that will fire with the length of the data to be @@ -719,22 +825,36 @@ class IUploadable(Interface): used, to compute encoding parameters. 
""" - def get_encryption_key(encoding_parameters): + def set_serialized_encoding_parameters(serialized_encoding_parameters): + """Tell me what encoding parameters will be used for my data. + + 'serialized_encoding_parameters' is a string which indicates how the + data will be encoded (codec name, blocksize, number of shares). + + I may use this when get_encryption_key() is called, to influence the + key that I return. Or, I may just ignore it. + + set_serialized_encoding_parameters() may be called 0 or 1 times. If + called, it must be called before get_encryption_key(). + """ + + def get_encryption_key(): """Return a Deferred that fires with a 16-byte AES key. This key will be used to encrypt the data. The key will also be hashed to derive - the StorageIndex. 'encoding_parameters' is a string which indicates - how the data will be encoded (codec name, blocksize, number of - shares): Uploadables may wish to use these parameters while computing - the encryption key. + the StorageIndex. Uploadables which want to achieve convergence should hash their file - contents and the encoding_parameters to form the key (which of course - requires a full pass over the data). Uploadables can use the - upload.ConvergentUploadMixin class to achieve this automatically. + contents and the serialized_encoding_parameters to form the key + (which of course requires a full pass over the data). Uploadables can + use the upload.ConvergentUploadMixin class to achieve this + automatically. Uploadables which do not care about convergence (or do not wish to make multiple passes over the data) can simply return a strongly-random 16 byte string. + + get_encryption_key() may be called multiple times: the IUploadable is + required to return the same value each time. 
""" def read(length): diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index 16b1a3a1..f2303374 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -5,6 +5,7 @@ from twisted.internet import defer from twisted.python.failure import Failure from allmydata import encode, upload, download, hashtree, uri from allmydata.util import hashutil +from allmydata.util.assertutil import _assert from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader class LostPeerError(Exception): @@ -151,24 +152,34 @@ class Encode(unittest.TestCase): # force use of multiple segments options = {"max_segment_size": max_segment_size} e = encode.Encoder(options) - e.set_size(datalen) - e.set_uploadable(upload.Data(data)) - nonkey = "\x00" * 16 - e.set_encryption_key(nonkey) - e.setup() - assert e.num_shares == NUM_SHARES # else we'll be completely confused - assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size - shareholders = {} + u = upload.Data(data) + eu = upload.EncryptAnUploadable(u) + d = e.set_encrypted_uploadable(eu) + all_shareholders = [] - for shnum in range(NUM_SHARES): - peer = FakeBucketWriterProxy() - shareholders[shnum] = peer - all_shareholders.append(peer) - e.set_shareholders(shareholders) - d = e.start() - def _check(roothash): - self.failUnless(isinstance(roothash, str)) - self.failUnlessEqual(len(roothash), 32) + def _ready(res): + k,happy,n = e.get_param("share_counts") + _assert(n == NUM_SHARES) # else we'll be completely confused + numsegs = e.get_param("num_segments") + _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS) + segsize = e.get_param("segment_size") + _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize, + NUM_SEGMENTS, segsize, + (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize) + + shareholders = {} + for shnum in range(NUM_SHARES): + peer = FakeBucketWriterProxy() + shareholders[shnum] = peer + 
all_shareholders.append(peer) + e.set_shareholders(shareholders) + return e.start() + d.addCallback(_ready) + + def _check(res): + (uri_extension_hash, required_shares, num_shares, file_size) = res + self.failUnless(isinstance(uri_extension_hash, str)) + self.failUnlessEqual(len(uri_extension_hash), 32) for i,peer in enumerate(all_shareholders): self.failUnless(peer.closed) self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS) @@ -278,43 +289,45 @@ class Roundtrip(unittest.TestCase): options = {"max_segment_size": max_segment_size, "needed_and_happy_and_total_shares": k_and_happy_and_n} e = encode.Encoder(options) - e.set_size(len(data)) - e.set_uploadable(upload.Data(data)) - nonkey = "\x00" * 16 - e.set_encryption_key(nonkey) - e.setup() - assert e.num_shares == NUM_SHARES # else we'll be completely confused + u = upload.Data(data) + eu = upload.EncryptAnUploadable(u) + d = e.set_encrypted_uploadable(eu) shareholders = {} - all_peers = [] - for shnum in range(NUM_SHARES): - mode = bucket_modes.get(shnum, "good") - peer = FakeBucketWriterProxy(mode) - shareholders[shnum] = peer - e.set_shareholders(shareholders) - - d = e.start() - def _sent(uri_extension_hash): - return (uri_extension_hash, e, shareholders) + def _ready(res): + k,happy,n = e.get_param("share_counts") + assert n == NUM_SHARES # else we'll be completely confused + all_peers = [] + for shnum in range(NUM_SHARES): + mode = bucket_modes.get(shnum, "good") + peer = FakeBucketWriterProxy(mode) + shareholders[shnum] = peer + e.set_shareholders(shareholders) + return e.start() + d.addCallback(_ready) + def _sent(res): + d1 = u.get_encryption_key() + d1.addCallback(lambda key: (res, key, shareholders)) + return d1 d.addCallback(_sent) return d - def recover(self, (uri_extension_hash, e, shareholders), AVAILABLE_SHARES, + def recover(self, (res, key, shareholders), AVAILABLE_SHARES, recover_mode): - key = e.key + (uri_extension_hash, required_shares, num_shares, file_size) = res if "corrupt_key" in 
recover_mode: # we corrupt the key, so that the decrypted data is corrupted and # will fail the plaintext hash check. Since we're manually # attaching shareholders, the fact that the storage index is also # corrupted doesn't matter. - key = flip_bit(e.key) + key = flip_bit(key) u = uri.CHKFileURI(key=key, uri_extension_hash=uri_extension_hash, - needed_shares=e.required_shares, - total_shares=e.num_shares, - size=e.file_size) + needed_shares=required_shares, + total_shares=num_shares, + size=file_size) URI = u.to_string() client = None @@ -551,7 +564,7 @@ class Roundtrip(unittest.TestCase): recover_mode=("corrupt_key")) def _done(res): self.failUnless(isinstance(res, Failure)) - self.failUnless(res.check(hashtree.BadHashError)) + self.failUnless(res.check(hashtree.BadHashError), res) d.addBoth(_done) return d diff --git a/src/allmydata/upload.py b/src/allmydata/upload.py index 7f9ee0ab..65bcd2dc 100644 --- a/src/allmydata/upload.py +++ b/src/allmydata/upload.py @@ -6,9 +6,10 @@ from twisted.internet import defer from twisted.application import service from foolscap import Referenceable -from allmydata.util import idlib, hashutil +from allmydata.util import hashutil from allmydata import encode, storage, hashtree, uri -from allmydata.interfaces import IUploadable, IUploader +from allmydata.interfaces import IUploadable, IUploader, IEncryptedUploadable +from allmydata.Crypto.Cipher import AES from cStringIO import StringIO import collections, random @@ -238,75 +239,172 @@ class Tahoe3PeerSelector: self.usable_peers.remove(peer) +class EncryptAnUploadable: + """This is a wrapper that takes an IUploadable and provides + IEncryptedUploadable.""" + implements(IEncryptedUploadable) + + def __init__(self, original): + self.original = original + self._encryptor = None + self._plaintext_hasher = hashutil.plaintext_hasher() + self._plaintext_segment_hasher = None + self._plaintext_segment_hashes = [] + self._params = None + + def get_size(self): + return 
self.original.get_size() + + def set_serialized_encoding_parameters(self, params): + self._params = params + + def _get_encryptor(self): + if self._encryptor: + return defer.succeed(self._encryptor) + + if self._params is not None: + self.original.set_serialized_encoding_parameters(self._params) + + d = self.original.get_encryption_key() + def _got(key): + e = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16) + self._encryptor = e + + storage_index = hashutil.storage_index_chk_hash(key) + assert isinstance(storage_index, str) + # There's no point to having the SI be longer than the key, so we + # specify that it is truncated to the same 128 bits as the AES key. + assert len(storage_index) == 16 # SHA-256 truncated to 128b + self._storage_index = storage_index + + return e + d.addCallback(_got) + return d + + def get_storage_index(self): + d = self._get_encryptor() + d.addCallback(lambda res: self._storage_index) + return d + + def set_segment_size(self, segsize): + self._segment_size = segsize + + def _get_segment_hasher(self): + p = self._plaintext_segment_hasher + if p: + left = self._segment_size - self._plaintext_segment_hashed_bytes + return p, left + p = hashutil.plaintext_segment_hasher() + self._plaintext_segment_hasher = p + self._plaintext_segment_hashed_bytes = 0 + return p, self._segment_size + + def _update_segment_hash(self, chunk): + offset = 0 + while offset < len(chunk): + p, segment_left = self._get_segment_hasher() + chunk_left = len(chunk) - offset + this_segment = min(chunk_left, segment_left) + p.update(chunk[offset:offset+this_segment]) + self._plaintext_segment_hashed_bytes += this_segment + + if self._plaintext_segment_hashed_bytes == self._segment_size: + # we've filled this segment + self._plaintext_segment_hashes.append(p.digest()) + self._plaintext_segment_hasher = None + + offset += this_segment + + def read_encrypted(self, length): + d = self._get_encryptor() + d.addCallback(lambda res: self.original.read(length)) + def 
_got(data):
+            assert isinstance(data, (tuple, list)), type(data)
+            data = list(data)
+            cryptdata = []
+            # we use data.pop(0) instead of 'for chunk in data' to save
+            # memory: each chunk is destroyed as soon as we're done with it.
+            while data:
+                chunk = data.pop(0)
+                self._plaintext_hasher.update(chunk)
+                self._update_segment_hash(chunk)
+                cryptdata.append(self._encryptor.encrypt(chunk))
+                del chunk
+            return cryptdata
+        d.addCallback(_got)
+        return d
+
+    def get_plaintext_segment_hashtree_nodes(self, num_segments):
+        if len(self._plaintext_segment_hashes) < num_segments:
+            # close out the last one
+            assert len(self._plaintext_segment_hashes) == num_segments-1
+            p, segment_left = self._get_segment_hasher()
+            self._plaintext_segment_hashes.append(p.digest())
+            del self._plaintext_segment_hasher
+        assert len(self._plaintext_segment_hashes) == num_segments
+        ht = hashtree.HashTree(self._plaintext_segment_hashes)
+        return defer.succeed(list(ht))
+
+    def get_plaintext_hash(self):
+        h = self._plaintext_hasher.digest()
+        return defer.succeed(h)
+
 
 class CHKUploader:
     peer_selector_class = Tahoe3PeerSelector
 
-    def __init__(self, client, uploadable, options={}):
+    def __init__(self, client, options={}):
         self._client = client
-        self._uploadable = IUploadable(uploadable)
         self._options = options
 
     def set_params(self, encoding_parameters):
         self._encoding_parameters = encoding_parameters
 
-        needed_shares, shares_of_happiness, total_shares = encoding_parameters
-        self.needed_shares = needed_shares
-        self.shares_of_happiness = shares_of_happiness
-        self.total_shares = total_shares
-
-    def start(self):
+    def start(self, uploadable):
         """Start uploading the file.
 
         This method returns a Deferred that will fire with the URI (a string)."""
-        log.msg("starting upload of %s" % self._uploadable)
+        uploadable = IUploadable(uploadable)
+        log.msg("starting upload of %s" % uploadable)
+
+        eu = EncryptAnUploadable(uploadable)
+        d = self.start_encrypted(eu)
+        def _uploaded(res):
+            d1 = uploadable.get_encryption_key()
+            d1.addCallback(lambda key: self._compute_uri(res, key))
+            return d1
+        d.addCallback(_uploaded)
+        return d
 
-        d = self._uploadable.get_size()
-        d.addCallback(self.setup_encoder)
-        d.addCallback(self._uploadable.get_encryption_key)
-        d.addCallback(self.setup_keys)
+    def start_encrypted(self, encrypted):
+        eu = IEncryptedUploadable(encrypted)
+
+        e = encode.Encoder(self._options)
+        e.set_params(self._encoding_parameters)
+        d = e.set_encrypted_uploadable(eu)
         d.addCallback(self.locate_all_shareholders)
-        d.addCallback(self.set_shareholders)
-        d.addCallback(lambda res: self._encoder.start())
-        d.addCallback(self._compute_uri)
+        d.addCallback(self.set_shareholders, e)
+        d.addCallback(lambda res: e.start())
+        # this fires with the uri_extension_hash and other data
         return d
 
-    def setup_encoder(self, size):
-        self._size = size
-        self._encoder = encode.Encoder(self._options)
-        self._encoder.set_size(size)
-        self._encoder.set_params(self._encoding_parameters)
-        self._encoder.set_uploadable(self._uploadable)
-        self._encoder.setup()
-        return self._encoder.get_serialized_params()
-
-    def setup_keys(self, key):
-        assert isinstance(key, str)
-        assert len(key) == 16 # AES-128
-        self._encryption_key = key
-        self._encoder.set_encryption_key(key)
-        storage_index = hashutil.storage_index_chk_hash(key)
-        assert isinstance(storage_index, str)
-        # There's no point to having the SI be longer than the key, so we
-        # specify that it is truncated to the same 128 bits as the AES key.
-        assert len(storage_index) == 16 # SHA-256 truncated to 128b
-        self._storage_index = storage_index
-        log.msg(" upload storage_index is [%s]" % (idlib.b2a(storage_index,)))
-
-
-    def locate_all_shareholders(self, ignored=None):
+    def locate_all_shareholders(self, encoder):
         peer_selector = self.peer_selector_class()
-        share_size = self._encoder.get_share_size()
-        block_size = self._encoder.get_block_size()
-        num_segments = self._encoder.get_num_segments()
+
+        storage_index = encoder.get_param("storage_index")
+        share_size = encoder.get_param("share_size")
+        block_size = encoder.get_param("block_size")
+        num_segments = encoder.get_param("num_segments")
+        k,desired,n = encoder.get_param("share_counts")
+
         gs = peer_selector.get_shareholders
-        d = gs(self._client,
-               self._storage_index, share_size, block_size,
-               num_segments, self.total_shares, self.shares_of_happiness)
+        d = gs(self._client, storage_index, share_size, block_size,
+               num_segments, n, desired)
         return d
 
-    def set_shareholders(self, used_peers):
+    def set_shareholders(self, used_peers, encoder):
         """
         @param used_peers: a sequence of PeerTracker objects
         """
@@ -317,16 +415,17 @@ class CHKUploader:
         for peer in used_peers:
             buckets.update(peer.buckets)
         assert len(buckets) == sum([len(peer.buckets) for peer in used_peers])
-        self._encoder.set_shareholders(buckets)
+        encoder.set_shareholders(buckets)
 
-    def _compute_uri(self, uri_extension_hash):
-        u = uri.CHKFileURI(key=self._encryption_key,
+    def _compute_uri(self, (uri_extension_hash,
+                            needed_shares, total_shares, size),
+                     key):
+        u = uri.CHKFileURI(key=key,
                            uri_extension_hash=uri_extension_hash,
-                           needed_shares=self.needed_shares,
-                           total_shares=self.total_shares,
-                           size=self._size,
+                           needed_shares=needed_shares,
+                           total_shares=total_shares,
+                           size=size,
                            )
-        assert u.storage_index == self._storage_index
         return u.to_string()
 
 def read_this_many_bytes(uploadable, size, prepend_data=[]):
@@ -348,17 +447,17 @@ def read_this_many_bytes(uploadable, size, prepend_data=[]):
 
 class LiteralUploader:
 
-    def __init__(self, client, uploadable, options={}):
+    def __init__(self, client, options={}):
         self._client = client
-        self._uploadable = IUploadable(uploadable)
         self._options = options
 
     def set_params(self, encoding_parameters):
         pass
 
-    def start(self):
-        d = self._uploadable.get_size()
-        d.addCallback(lambda size: read_this_many_bytes(self._uploadable, size))
+    def start(self, uploadable):
+        uploadable = IUploadable(uploadable)
+        d = uploadable.get_size()
+        d.addCallback(lambda size: read_this_many_bytes(uploadable, size))
         d.addCallback(lambda data: uri.LiteralFileURI("".join(data)))
         d.addCallback(lambda u: u.to_string())
         return d
@@ -370,25 +469,40 @@ class LiteralUploader:
 class ConvergentUploadMixin:
     # to use this, the class it is mixed in to must have a seekable
     # filehandle named self._filehandle
-
-    def get_encryption_key(self, encoding_parameters):
-        f = self._filehandle
-        enckey_hasher = hashutil.key_hasher()
-        #enckey_hasher.update(encoding_parameters) # TODO
-        f.seek(0)
-        BLOCKSIZE = 64*1024
-        while True:
-            data = f.read(BLOCKSIZE)
-            if not data:
-                break
-            enckey_hasher.update(data)
-        enckey = enckey_hasher.digest()[:16]
-        f.seek(0)
-        return defer.succeed(enckey)
+    _params = None
+    _key = None
+
+    def set_serialized_encoding_parameters(self, params):
+        self._params = params
+        # ignored for now
+
+    def get_encryption_key(self):
+        if self._key is None:
+            f = self._filehandle
+            enckey_hasher = hashutil.key_hasher()
+            #enckey_hasher.update(encoding_parameters) # TODO
+            f.seek(0)
+            BLOCKSIZE = 64*1024
+            while True:
+                data = f.read(BLOCKSIZE)
+                if not data:
+                    break
+                enckey_hasher.update(data)
+            f.seek(0)
+            self._key = enckey_hasher.digest()[:16]
+
+        return defer.succeed(self._key)
 
 class NonConvergentUploadMixin:
+    _key = None
+
+    def set_serialized_encoding_parameters(self, params):
+        pass
+
-    def get_encryption_key(self, encoding_parameters):
-        return defer.succeed(os.urandom(16))
+    def get_encryption_key(self, encoding_parameters=None):
+        if self._key is None:
+            self._key = os.urandom(16)
+        return defer.succeed(self._key)
 
 
 class FileHandle(ConvergentUploadMixin):
@@ -446,10 +560,10 @@ class Uploader(service.MultiService):
             uploader_class = self.uploader_class
             if size <= self.URI_LIT_SIZE_THRESHOLD:
                 uploader_class = LiteralUploader
-            uploader = uploader_class(self.parent, uploadable, options)
+            uploader = uploader_class(self.parent, options)
             uploader.set_params(self.parent.get_encoding_parameters()
                                 or self.DEFAULT_ENCODING_PARAMETERS)
-            return uploader.start()
+            return uploader.start(uploadable)
         d.addCallback(_got_size)
         def _done(res):
             uploadable.close()