refactor upload/encode, to split encrypt and encode responsibilities
author    Brian Warner <warner@allmydata.com>
Tue, 24 Jul 2007 02:31:53 +0000 (19:31 -0700)
committer Brian Warner <warner@allmydata.com>
Tue, 24 Jul 2007 02:31:53 +0000 (19:31 -0700)
src/allmydata/encode.py
src/allmydata/interfaces.py
src/allmydata/test/test_encode.py
src/allmydata/upload.py
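
The commit moves all knowledge of AES, plaintext hashing, and storage-index
derivation out of the Encoder and into a new EncryptAnUploadable wrapper, so
the Encoder deals only in crypttext. A minimal sketch of the resulting call
flow, using names from this commit (peer selection, error handling, and URI
construction are elided):

from allmydata import encode, upload

def encode_and_push(uploadable, shareholders):
    # wrap the plaintext IUploadable: the wrapper owns the AES-CTR
    # encryptor, the plaintext hashers, and the storage index
    eu = upload.EncryptAnUploadable(uploadable)

    e = encode.Encoder({})              # empty options dict
    d = e.set_encrypted_uploadable(eu)  # Deferred that fires with the Encoder
    def _ready(encoder):
        # shareholders: dict mapping share number (0..n-1) to an
        # IStorageBucketWriter provider
        encoder.set_shareholders(shareholders)
        return encoder.start()
    d.addCallback(_ready)
    # start() now fires with (uri_extension_hash, needed_shares,
    # total_shares, size); together with the key from
    # uploadable.get_encryption_key(), that is enough to build the URI
    return d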

diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py
index 3290aec7b2d8e309f68982767e3ea1b2d1f52d2f..e1de2eaeb3580aaefe09a4671f85664a0f29445a 100644
@@ -5,11 +5,11 @@ from twisted.internet import defer
 from twisted.python import log
 from allmydata import uri
 from allmydata.hashtree import HashTree
-from allmydata.Crypto.Cipher import AES
 from allmydata.util import mathutil, hashutil
 from allmydata.util.assertutil import _assert
 from allmydata.codec import CRSEncoder
-from allmydata.interfaces import IEncoder, IStorageBucketWriter
+from allmydata.interfaces import IEncoder, IStorageBucketWriter, \
+     IEncryptedUploadable
 
 """
 
@@ -87,21 +87,20 @@ class Encoder(object):
         self.SHARES_OF_HAPPINESS = happy
         self.TOTAL_SHARES = n
         self.uri_extension_data = {}
+        self._codec = None
 
     def set_size(self, size):
+        assert not self._codec
         self.file_size = size
 
     def set_params(self, encoding_parameters):
+        assert not self._codec
         k,d,n = encoding_parameters
         self.NEEDED_SHARES = k
         self.SHARES_OF_HAPPINESS = d
         self.TOTAL_SHARES = n
 
-    def set_uploadable(self, uploadable):
-        self._uploadable = uploadable
-
-    def setup(self):
-
+    def _setup_codec(self):
         self.num_shares = self.TOTAL_SHARES
         self.required_shares = self.NEEDED_SHARES
         self.shares_of_happiness = self.SHARES_OF_HAPPINESS
@@ -110,9 +109,9 @@ class Encoder(object):
         # this must be a multiple of self.required_shares
         self.segment_size = mathutil.next_multiple(self.segment_size,
                                                    self.required_shares)
-        self._setup_codec()
 
-    def _setup_codec(self):
+        # now set up the codec
+
         assert self.segment_size % self.required_shares == 0
         self.num_segments = mathutil.div_ceil(self.file_size,
                                               self.segment_size)
@@ -150,24 +149,48 @@ class Encoder(object):
                                     self.required_shares, self.num_shares)
         data['tail_codec_params'] = self._tail_codec.get_serialized_params()
 
-    def get_serialized_params(self):
-        return self._codec.get_serialized_params()
-
-    def set_encryption_key(self, key):
-        assert isinstance(key, str)
-        assert len(key) == 16 # AES-128
-        self.key = key
-
-    def get_share_size(self):
+    def _get_share_size(self):
         share_size = mathutil.div_ceil(self.file_size, self.required_shares)
-        overhead = self.compute_overhead()
+        overhead = self._compute_overhead()
         return share_size + overhead
-    def compute_overhead(self):
+
+    def _compute_overhead(self):
         return 0
-    def get_block_size(self):
-        return self._codec.get_block_size()
-    def get_num_segments(self):
-        return self.num_segments
+
+    def set_encrypted_uploadable(self, uploadable):
+        u = self._uploadable = IEncryptedUploadable(uploadable)
+        d = u.get_size()
+        d.addCallback(self.set_size)
+        d.addCallback(lambda res: self.get_param("serialized_params"))
+        d.addCallback(u.set_serialized_encoding_parameters)
+        d.addCallback(lambda res: u.get_storage_index())
+        def _done(storage_index):
+            self._storage_index = storage_index
+            return self
+        d.addCallback(_done)
+        return d
+
+    def get_param(self, name):
+        if not self._codec:
+            self._setup_codec()
+
+        if name == "storage_index":
+            return self._storage_index
+        elif name == "share_counts":
+            return (self.required_shares, self.shares_of_happiness,
+                    self.num_shares)
+        elif name == "num_segments":
+            return self.num_segments
+        elif name == "segment_size":
+            return self.segment_size
+        elif name == "block_size":
+            return self._codec.get_block_size()
+        elif name == "share_size":
+            return self._get_share_size()
+        elif name == "serialized_params":
+            return self._codec.get_serialized_params()
+        else:
+            raise KeyError("unknown parameter name '%s'" % name)
 
     def set_shareholders(self, landlords):
         assert isinstance(landlords, dict)
@@ -177,12 +200,20 @@ class Encoder(object):
 
     def start(self):
         #paddedsize = self._size + mathutil.pad_size(self._size, self.needed_shares)
-        self._plaintext_hasher = hashutil.plaintext_hasher()
-        self._plaintext_hashes = []
+        if not self._codec:
+            self._setup_codec()
+
         self._crypttext_hasher = hashutil.crypttext_hasher()
         self._crypttext_hashes = []
-        self.setup_encryption()
-        d = defer.succeed(None)
+        self.segment_num = 0
+        self.subshare_hashes = [[] for x in range(self.num_shares)]
+        # subshare_hashes[i] is a list that will be accumulated and then sent
+        # to landlord[i]. This list contains a hash of each segment_share
+        # that we sent to that landlord.
+        self.share_root_hashes = [None] * self.num_shares
+
+        d = defer.maybeDeferred(self._uploadable.set_segment_size,
+                                self.segment_size)
 
         for l in self.landlords.values():
             d.addCallback(lambda res, l=l: l.start())
@@ -198,7 +229,7 @@ class Encoder(object):
         d.addCallback(lambda res: self._encode_tail_segment(last_segnum))
         d.addCallback(self._send_segment, last_segnum)
 
-        d.addCallback(lambda res: self.finish_flat_hashes())
+        d.addCallback(lambda res: self.finish_hashing())
 
         d.addCallback(lambda res:
                       self.send_plaintext_hash_tree_to_all_shareholders())
@@ -212,16 +243,6 @@ class Encoder(object):
         d.addCallbacks(lambda res: self.done(), self.err)
         return d
 
-    def setup_encryption(self):
-        self.cryptor = AES.new(key=self.key, mode=AES.MODE_CTR,
-                               counterstart="\x00"*16)
-        self.segment_num = 0
-        self.subshare_hashes = [[] for x in range(self.num_shares)]
-        # subshare_hashes[i] is a list that will be accumulated and then send
-        # to landlord[i]. This list contains a hash of each segment_share
-        # that we sent to that landlord.
-        self.share_root_hashes = [None] * self.num_shares
-
     def _encode_segment(self, segnum):
         codec = self._codec
 
@@ -241,7 +262,6 @@ class Encoder(object):
         # of additional shares which can be substituted if the primary ones
         # are unavailable
 
-        plaintext_segment_hasher = hashutil.plaintext_segment_hasher()
         crypttext_segment_hasher = hashutil.crypttext_segment_hasher()
 
         # memory footprint: we only hold a tiny piece of the plaintext at any
@@ -252,12 +272,10 @@ class Encoder(object):
         # footprint to 500KiB at the expense of more hash-tree overhead.
 
         d = self._gather_data(self.required_shares, input_piece_size,
-                              plaintext_segment_hasher,
                               crypttext_segment_hasher)
         def _done(chunks):
             for c in chunks:
                 assert len(c) == input_piece_size
-            self._plaintext_hashes.append(plaintext_segment_hasher.digest())
             self._crypttext_hashes.append(crypttext_segment_hasher.digest())
             # during this call, we hit 5*segsize memory
             return codec.encode(chunks)
@@ -269,11 +287,9 @@ class Encoder(object):
         codec = self._tail_codec
         input_piece_size = codec.get_block_size()
 
-        plaintext_segment_hasher = hashutil.plaintext_segment_hasher()
         crypttext_segment_hasher = hashutil.crypttext_segment_hasher()
 
         d = self._gather_data(self.required_shares, input_piece_size,
-                              plaintext_segment_hasher,
                               crypttext_segment_hasher,
                               allow_short=True)
         def _done(chunks):
@@ -281,14 +297,13 @@ class Encoder(object):
                 # a short trailing chunk will have been padded by
                 # _gather_data
                 assert len(c) == input_piece_size
-            self._plaintext_hashes.append(plaintext_segment_hasher.digest())
             self._crypttext_hashes.append(crypttext_segment_hasher.digest())
             return codec.encode(chunks)
         d.addCallback(_done)
         return d
 
     def _gather_data(self, num_chunks, input_chunk_size,
-                     plaintext_segment_hasher, crypttext_segment_hasher,
+                     crypttext_segment_hasher,
                      allow_short=False,
                      previous_chunks=[]):
         """Return a Deferred that will fire when the required number of
@@ -299,19 +314,13 @@ class Encoder(object):
         if not num_chunks:
             return defer.succeed(previous_chunks)
 
-        d = self._uploadable.read(input_chunk_size)
+        d = self._uploadable.read_encrypted(input_chunk_size)
         def _got(data):
             encrypted_pieces = []
             length = 0
-            # we use data.pop(0) instead of 'for input_piece in data' to save
-            # memory: each piece is destroyed as soon as we're done with it.
             while data:
-                input_piece = data.pop(0)
-                length += len(input_piece)
-                plaintext_segment_hasher.update(input_piece)
-                self._plaintext_hasher.update(input_piece)
-                encrypted_piece = self.cryptor.encrypt(input_piece)
-                assert len(encrypted_piece) == len(input_piece)
+                encrypted_piece = data.pop(0)
+                length += len(encrypted_piece)
                 crypttext_segment_hasher.update(encrypted_piece)
                 self._crypttext_hasher.update(encrypted_piece)
                 encrypted_pieces.append(encrypted_piece)
@@ -331,7 +340,6 @@ class Encoder(object):
         d.addCallback(_got)
         d.addCallback(lambda chunks:
                       self._gather_data(num_chunks-1, input_chunk_size,
-                                        plaintext_segment_hasher,
                                         crypttext_segment_hasher,
                                         allow_short, chunks))
         return d
@@ -403,20 +411,28 @@ class Encoder(object):
             d0.addErrback(_eatNotEnoughPeersError)
         return d
 
-    def finish_flat_hashes(self):
-        plaintext_hash = self._plaintext_hasher.digest()
+    def finish_hashing(self):
         crypttext_hash = self._crypttext_hasher.digest()
-        self.uri_extension_data["plaintext_hash"] = plaintext_hash
         self.uri_extension_data["crypttext_hash"] = crypttext_hash
+        u = self._uploadable
+        d = u.get_plaintext_hash()
+        def _got(plaintext_hash):
+            self.uri_extension_data["plaintext_hash"] = plaintext_hash
+            return u.get_plaintext_segment_hashtree_nodes(self.num_segments)
+        d.addCallback(_got)
+        def _got_hashtree_nodes(t):
+            self.uri_extension_data["plaintext_root_hash"] = t[0]
+            self._plaintext_hashtree_nodes = t
+        d.addCallback(_got_hashtree_nodes)
+        return d
 
     def send_plaintext_hash_tree_to_all_shareholders(self):
         log.msg("%s sending plaintext hash tree" % self)
-        t = HashTree(self._plaintext_hashes)
-        all_hashes = list(t)
-        self.uri_extension_data["plaintext_root_hash"] = t[0]
         dl = []
         for shareid in self.landlords.keys():
-            dl.append(self.send_plaintext_hash_tree(shareid, all_hashes))
+            d = self.send_plaintext_hash_tree(shareid,
+                                              self._plaintext_hashtree_nodes)
+            dl.append(d)
         return self._gather_responses(dl)
 
     def send_plaintext_hash_tree(self, shareid, all_hashes):
@@ -528,7 +544,8 @@ class Encoder(object):
 
     def done(self):
         log.msg("%s: upload done" % self)
-        return self.uri_extension_hash
+        return (self.uri_extension_hash, self.required_shares,
+                self.num_shares, self.file_size)
 
     def err(self, f):
         log.msg("%s: upload failed: %s" % (self, f)) # UNUSUAL
diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index cffc764954b09de7555670119b71454e6e8ef3d1..5e0dce640ec7fdff85ae8626c4c50566384e0f72 100644
@@ -588,60 +588,110 @@ class ICodecDecoder(Interface):
         """
 
 class IEncoder(Interface):
-    """I take a file-like object that provides a sequence of bytes and a list
-    of shareholders, then encrypt, encode, hash, and deliver shares to those
-    shareholders. I will compute all the necessary Merkle hash trees that are
-    necessary to validate the data that eventually comes back from the
-    shareholders. I provide the root hash of the hash tree, and the encoding
-    parameters, both of which must be included in the URI.
+    """I take an object that provides IEncryptedUploadable, which provides
+    encrypted data, and a list of shareholders. I then encode, hash, and
+    deliver shares to those shareholders. I will compute all the necessary
+    Merkle hash trees that are necessary to validate the crypttext that
+    eventually comes back from the shareholders. I provide the URI Extension
+    Block Hash, and the encoding parameters, both of which must be included
+    in the URI.
 
     I do not choose shareholders; that is left to the IUploader. I must be
     given a dict of RemoteReferences to storage buckets that are ready and
     willing to receive data.
     """
 
-    def setup(infile):
-        """I take a file-like object (providing seek, tell, and read) from
-        which all the plaintext data that is to be uploaded can be read. I
-        will seek to the beginning of the file before reading any data.
-        setup() must be called before making any other calls, in particular
-        before calling get_reservation_size().
+    def set_size(size):
+        """Specify the number of bytes that will be encoded. This must be
+        performed before get_param("serialized_params") can be called.
         """
+    def set_params(params):
+        """Override the default encoding parameters. 'params' is a tuple of
+        (k,d,n), where 'k' is the number of required shares, 'd' is the
+        shares_of_happiness, and 'n' is the total number of shares that will
+        be created.
 
-    def get_share_size():
-        """I return the size of the data that will be stored on each
-        shareholder. This is aggregate amount of data that will be sent to
-        the shareholder, summed over all the put_block() calls I will ever
-        make.
+        Encoding parameters can be set in three ways. 1: The Encoder class
+        provides defaults (25/75/100). 2: The Encoder can be constructed with
+        an 'options' dictionary, in which the
+        'needed_and_happy_and_total_shares' key can be a (k,d,n) tuple. 3:
+        set_params((k,d,n)) can be called.
 
-        TODO: this might also include some amount of overhead, like the size
-        of all the hashes. We need to decide whether this is useful or not.
+        If you intend to use set_params(), you must call it before
+        any call to get_param() is made.
+        """
+
+    def set_encrypted_uploadable(u):
+        """Provide a source of encrypted upload data. 'u' must implement
+        IEncryptedUploadable.
+
+        When this is called, the IEncryptedUploadable will be queried for its
+        length and the storage_index that should be used.
 
-        It is useful to determine this size before asking potential
-        shareholders whether they will grant a lease or not, since their
-        answers will depend upon how much space we need.
+        This returns a Deferred that fires with this Encoder instance.
+
+        This must be performed before start() can be called.
         """
 
-    def get_block_size(): # TODO: can we avoid exposing this?
-        """I return the size of the individual blocks that will be delivered
-        to a shareholder's put_block() method. By knowing this, the
-        shareholder will be able to keep all blocks in a single file and
-        still provide random access when reading them.
+    def get_param(name):
+        """Return an encoding parameter, by name.
+
+        'storage_index': return a string with the storage index (a 16-byte
+                         truncated SHA-256 hash) to which these shares
+                         should be pushed.
+
+        'share_counts': return a tuple describing how many shares are used:
+                        (needed_shares, shares_of_happiness, total_shares)
+
+        'num_segments': return an int with the number of segments that
+                        will be encoded.
+
+        'segment_size': return an int with the size of each segment.
+
+        'block_size': return the size of the individual blocks that will
+                      be delivered to a shareholder's put_block() method. By
+                      knowing this, the shareholder will be able to keep all
+                      blocks in a single file and still provide random access
+                      when reading them. # TODO: can we avoid exposing this?
+
+        'share_size': an int with the size of the data that will be stored
+                      on each shareholder. This is the aggregate amount of data
+                      that will be sent to the shareholder, summed over all
+                      the put_block() calls I will ever make. It is useful to
+                      determine this size before asking potential
+                      shareholders whether they will grant a lease or not,
+                      since their answers will depend upon how much space we
+                      need. TODO: this might also include some amount of
+                      overhead, like the size of all the hashes. We need to
+                      decide whether this is useful or not.
+
+        'serialized_params': a string with a concise description of the
+                             codec name and its parameters. This may be passed
+                             into the IUploadable to let it make sure that
+                             the same file encoded with different parameters
+                             will result in different storage indexes.
+
+        Once this is called, set_size() and set_params() may not be called.
         """
 
     def set_shareholders(shareholders):
-        """I take a dictionary that maps share identifiers (small integers,
-        starting at 0) to RemoteReferences that provide RIBucketWriter. This
-        must be called before start().
-        """
+        """Tell the encoder where to put the encoded shares. 'shareholders'
+        must be a dictionary that maps share number (an integer ranging from
+        0 to n-1) to an instance that provides IStorageBucketWriter. This
+        must be performed before start() can be called."""
 
     def start():
-        """I start the upload. This process involves reading data from the
-        input file, encrypting it, encoding the pieces, uploading the shares
+        """Begin the encode/upload process. This involves reading encrypted
+        data from the IEncryptedUploadable, encoding it, uploading the shares
         to the shareholders, then sending the hash trees.
 
-        I return a Deferred that fires with the hash of the uri_extension
-        data block.
+        set_encrypted_uploadable() and set_shareholders() must be called
+        before this can be invoked.
+
+        This returns a Deferred that fires with a tuple of
+        (uri_extension_hash, needed_shares, total_shares, size) when the
+        upload process is complete. This information, plus the encryption
+        key, is sufficient to construct the URI.
         """
 
 class IDecoder(Interface):
@@ -712,6 +762,62 @@ class IDownloader(Interface):
         Returns a Deferred that fires (with the results of target.finish)
         when the download is finished, or errbacks if something went wrong."""
 
+class IEncryptedUploadable(Interface):
+    def get_size():
+        """This behaves just like IUploadable.get_size()."""
+
+    def set_serialized_encoding_parameters(serialized_encoding_parameters):
+        """Tell me what encoding parameters will be used for my data.
+
+        'serialized_encoding_parameters' is a string which indicates how the
+        data will be encoded (codec name, blocksize, number of shares).
+
+        I may use this when get_storage_index() is called, to influence the
+        index that I return. Or, I may just ignore it.
+
+        set_serialized_encoding_parameters() may be called 0 or 1 times. If
+        called, it must be called before get_storage_index().
+        """
+
+    def get_storage_index():
+        """Return a Deferred that fires with a 16-byte storage index. This
+        value may be influenced by the parameters earlier set by
+        set_serialized_encoding_parameters().
+        """
+
+    def set_segment_size(segment_size):
+        """Set the segment size, to allow the IEncryptedUploadable to
+        accurately create the plaintext segment hash tree. This must be
+        called before any calls to read_encrypted."""
+
+    def read_encrypted(length):
+        """This behaves just like IUploadable.read(), but returns crypttext
+        instead of plaintext. set_segment_size() must be called before the
+        first call to read_encrypted()."""
+
+    def get_plaintext_segment_hashtree_nodes(num_segments):
+        """Get the nodes of a merkle hash tree over the plaintext segments.
+
+        This returns a Deferred which fires with a sequence of hashes. Each
+        hash is a node of a merkle hash tree, generally obtained from::
+
+         tuple(HashTree(segment_hashes))
+
+        'num_segments' is used to assert that the number of segments that the
+        IEncryptedUploadable handled matches the number of segments that the
+        encoder was expecting.
+        """
+
+    def get_plaintext_hash():
+        """Get the hash of the whole plaintext.
+
+        This returns a Deferred which fires with a tagged SHA-256 hash of the
+        whole plaintext, obtained from hashutil.plaintext_hash(data).
+        """
+
+    def close():
+        """Just like IUploadable.close()."""
+
 class IUploadable(Interface):
     def get_size():
         """Return a Deferred that will fire with the length of the data to be
@@ -719,22 +825,36 @@ class IUploadable(Interface):
         used, to compute encoding parameters.
         """
 
-    def get_encryption_key(encoding_parameters):
+    def set_serialized_encoding_parameters(serialized_encoding_parameters):
+        """Tell me what encoding parameters will be used for my data.
+
+        'serialized_encoding_parameters' is a string which indicates how the
+        data will be encoded (codec name, blocksize, number of shares).
+
+        I may use this when get_encryption_key() is called, to influence the
+        key that I return. Or, I may just ignore it.
+
+        set_serialized_encoding_parameters() may be called 0 or 1 times. If
+        called, it must be called before get_encryption_key().
+        """
+
+    def get_encryption_key():
         """Return a Deferred that fires with a 16-byte AES key. This key will
         be used to encrypt the data. The key will also be hashed to derive
-        the StorageIndex. 'encoding_parameters' is a string which indicates
-        how the data will be encoded (codec name, blocksize, number of
-        shares): Uploadables may wish to use these parameters while computing
-        the encryption key.
+        the StorageIndex.
 
         Uploadables which want to achieve convergence should hash their file
-        contents and the encoding_parameters to form the key (which of course
-        requires a full pass over the data). Uploadables can use the
-        upload.ConvergentUploadMixin class to achieve this automatically.
+        contents and the serialized_encoding_parameters to form the key
+        (which of course requires a full pass over the data). Uploadables can
+        use the upload.ConvergentUploadMixin class to achieve this
+        automatically.
 
         Uploadables which do not care about convergence (or do not wish to
         make multiple passes over the data) can simply return a
         strongly-random 16 byte string.
+
+        get_encryption_key() may be called multiple times: the IUploadable is
+        required to return the same value each time.
         """
 
     def read(length):
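
The convergence contract described in get_encryption_key() above can be made
concrete. A hedged sketch of a convergent key derivation, modeled on
upload.ConvergentUploadMixin from this commit; the line mixing the serialized
parameters into the key is an assumption here, since the shipped mixin still
leaves that step as a commented-out TODO:

from twisted.internet import defer
from allmydata.util import hashutil

class ConvergentKeyMixin:
    # sketch: the class this is mixed into must provide a seekable
    # filehandle named self._filehandle
    _params = None
    _key = None

    def set_serialized_encoding_parameters(self, params):
        self._params = params

    def get_encryption_key(self):
        if self._key is None:
            h = hashutil.key_hasher()
            if self._params is not None:
                h.update(self._params)  # assumption: the shipped mixin
                                        # leaves this as a TODO
            f = self._filehandle
            f.seek(0)
            while True:
                data = f.read(64*1024)
                if not data:
                    break
                h.update(data)
            f.seek(0)
            self._key = h.digest()[:16]  # 16 bytes: AES-128 key, also
                                         # hashed to form the storage index
        return defer.succeed(self._key)
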
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index 16b1a3a1fb16860ef9af59004b904abe2a1901ff..f23033740c17f944b26836db11c37dd5e8478486 100644
@@ -5,6 +5,7 @@ from twisted.internet import defer
 from twisted.python.failure import Failure
 from allmydata import encode, upload, download, hashtree, uri
 from allmydata.util import hashutil
+from allmydata.util.assertutil import _assert
 from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
 
 class LostPeerError(Exception):
@@ -151,24 +152,34 @@ class Encode(unittest.TestCase):
         # force use of multiple segments
         options = {"max_segment_size": max_segment_size}
         e = encode.Encoder(options)
-        e.set_size(datalen)
-        e.set_uploadable(upload.Data(data))
-        nonkey = "\x00" * 16
-        e.set_encryption_key(nonkey)
-        e.setup()
-        assert e.num_shares == NUM_SHARES # else we'll be completely confused
-        assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
-        shareholders = {}
+        u = upload.Data(data)
+        eu = upload.EncryptAnUploadable(u)
+        d = e.set_encrypted_uploadable(eu)
+
         all_shareholders = []
-        for shnum in range(NUM_SHARES):
-            peer = FakeBucketWriterProxy()
-            shareholders[shnum] = peer
-            all_shareholders.append(peer)
-        e.set_shareholders(shareholders)
-        d = e.start()
-        def _check(roothash):
-            self.failUnless(isinstance(roothash, str))
-            self.failUnlessEqual(len(roothash), 32)
+        def _ready(res):
+            k,happy,n = e.get_param("share_counts")
+            _assert(n == NUM_SHARES) # else we'll be completely confused
+            numsegs = e.get_param("num_segments")
+            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
+            segsize = e.get_param("segment_size")
+            _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize,
+                     NUM_SEGMENTS, segsize,
+                     (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize)
+
+            shareholders = {}
+            for shnum in range(NUM_SHARES):
+                peer = FakeBucketWriterProxy()
+                shareholders[shnum] = peer
+                all_shareholders.append(peer)
+            e.set_shareholders(shareholders)
+            return e.start()
+        d.addCallback(_ready)
+
+        def _check(res):
+            (uri_extension_hash, required_shares, num_shares, file_size) = res
+            self.failUnless(isinstance(uri_extension_hash, str))
+            self.failUnlessEqual(len(uri_extension_hash), 32)
             for i,peer in enumerate(all_shareholders):
                 self.failUnless(peer.closed)
                 self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
@@ -278,43 +289,45 @@ class Roundtrip(unittest.TestCase):
         options = {"max_segment_size": max_segment_size,
                    "needed_and_happy_and_total_shares": k_and_happy_and_n}
         e = encode.Encoder(options)
-        e.set_size(len(data))
-        e.set_uploadable(upload.Data(data))
-        nonkey = "\x00" * 16
-        e.set_encryption_key(nonkey)
-        e.setup()
-        assert e.num_shares == NUM_SHARES # else we'll be completely confused
+        u = upload.Data(data)
+        eu = upload.EncryptAnUploadable(u)
+        d = e.set_encrypted_uploadable(eu)
 
         shareholders = {}
-        all_peers = []
-        for shnum in range(NUM_SHARES):
-            mode = bucket_modes.get(shnum, "good")
-            peer = FakeBucketWriterProxy(mode)
-            shareholders[shnum] = peer
-        e.set_shareholders(shareholders)
-
-        d = e.start()
-        def _sent(uri_extension_hash):
-            return (uri_extension_hash, e, shareholders)
+        def _ready(res):
+            k,happy,n = e.get_param("share_counts")
+            assert n == NUM_SHARES # else we'll be completely confused
+            all_peers = []
+            for shnum in range(NUM_SHARES):
+                mode = bucket_modes.get(shnum, "good")
+                peer = FakeBucketWriterProxy(mode)
+                shareholders[shnum] = peer
+            e.set_shareholders(shareholders)
+            return e.start()
+        d.addCallback(_ready)
+        def _sent(res):
+            d1 = u.get_encryption_key()
+            d1.addCallback(lambda key: (res, key, shareholders))
+            return d1
         d.addCallback(_sent)
         return d
 
-    def recover(self, (uri_extension_hash, e, shareholders), AVAILABLE_SHARES,
+    def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
                 recover_mode):
-        key = e.key
+        (uri_extension_hash, required_shares, num_shares, file_size) = res
 
         if "corrupt_key" in recover_mode:
             # we corrupt the key, so that the decrypted data is corrupted and
             # will fail the plaintext hash check. Since we're manually
             # attaching shareholders, the fact that the storage index is also
             # corrupted doesn't matter.
-            key = flip_bit(e.key)
+            key = flip_bit(key)
 
         u = uri.CHKFileURI(key=key,
                            uri_extension_hash=uri_extension_hash,
-                           needed_shares=e.required_shares,
-                           total_shares=e.num_shares,
-                           size=e.file_size)
+                           needed_shares=required_shares,
+                           total_shares=num_shares,
+                           size=file_size)
         URI = u.to_string()
 
         client = None
@@ -551,7 +564,7 @@ class Roundtrip(unittest.TestCase):
                                   recover_mode=("corrupt_key"))
         def _done(res):
             self.failUnless(isinstance(res, Failure))
-            self.failUnless(res.check(hashtree.BadHashError))
+            self.failUnless(res.check(hashtree.BadHashError), res)
         d.addBoth(_done)
         return d
 
diff --git a/src/allmydata/upload.py b/src/allmydata/upload.py
index 7f9ee0ab6317592e59f5b50d3ad223fad3af7680..65bcd2dc16222974a2b5ee3291c13100c780eb54 100644
@@ -6,9 +6,10 @@ from twisted.internet import defer
 from twisted.application import service
 from foolscap import Referenceable
 
-from allmydata.util import idlib, hashutil
+from allmydata.util import hashutil
 from allmydata import encode, storage, hashtree, uri
-from allmydata.interfaces import IUploadable, IUploader
+from allmydata.interfaces import IUploadable, IUploader, IEncryptedUploadable
+from allmydata.Crypto.Cipher import AES
 
 from cStringIO import StringIO
 import collections, random
@@ -238,75 +239,172 @@ class Tahoe3PeerSelector:
         self.usable_peers.remove(peer)
 
 
+class EncryptAnUploadable:
+    """This is a wrapper that takes an IUploadable and provides
+    IEncryptedUploadable."""
+    implements(IEncryptedUploadable)
+
+    def __init__(self, original):
+        self.original = original
+        self._encryptor = None
+        self._plaintext_hasher = hashutil.plaintext_hasher()
+        self._plaintext_segment_hasher = None
+        self._plaintext_segment_hashes = []
+        self._params = None
+
+    def get_size(self):
+        return self.original.get_size()
+
+    def set_serialized_encoding_parameters(self, params):
+        self._params = params
+
+    def _get_encryptor(self):
+        if self._encryptor:
+            return defer.succeed(self._encryptor)
+
+        if self._params is not None:
+            self.original.set_serialized_encoding_parameters(self._params)
+
+        d = self.original.get_encryption_key()
+        def _got(key):
+            e = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16)
+            self._encryptor = e
+
+            storage_index = hashutil.storage_index_chk_hash(key)
+            assert isinstance(storage_index, str)
+            # There's no point to having the SI be longer than the key, so we
+            # specify that it is truncated to the same 128 bits as the AES key.
+            assert len(storage_index) == 16  # SHA-256 truncated to 128b
+            self._storage_index = storage_index
+
+            return e
+        d.addCallback(_got)
+        return d
+
+    def get_storage_index(self):
+        d = self._get_encryptor()
+        d.addCallback(lambda res: self._storage_index)
+        return d
+
+    def set_segment_size(self, segsize):
+        self._segment_size = segsize
+
+    def _get_segment_hasher(self):
+        p = self._plaintext_segment_hasher
+        if p:
+            left = self._segment_size - self._plaintext_segment_hashed_bytes
+            return p, left
+        p = hashutil.plaintext_segment_hasher()
+        self._plaintext_segment_hasher = p
+        self._plaintext_segment_hashed_bytes = 0
+        return p, self._segment_size
+
+    def _update_segment_hash(self, chunk):
+        offset = 0
+        while offset < len(chunk):
+            p, segment_left = self._get_segment_hasher()
+            chunk_left = len(chunk) - offset
+            this_segment = min(chunk_left, segment_left)
+            p.update(chunk[offset:offset+this_segment])
+            self._plaintext_segment_hashed_bytes += this_segment
+
+            if self._plaintext_segment_hashed_bytes == self._segment_size:
+                # we've filled this segment
+                self._plaintext_segment_hashes.append(p.digest())
+                self._plaintext_segment_hasher = None
+
+            offset += this_segment
+
+    def read_encrypted(self, length):
+        d = self._get_encryptor()
+        d.addCallback(lambda res: self.original.read(length))
+        def _got(data):
+            assert isinstance(data, (tuple, list)), type(data)
+            data = list(data)
+            cryptdata = []
+            # we use data.pop(0) instead of 'for chunk in data' to save
+            # memory: each chunk is destroyed as soon as we're done with it.
+            while data:
+                chunk = data.pop(0)
+                self._plaintext_hasher.update(chunk)
+                self._update_segment_hash(chunk)
+                cryptdata.append(self._encryptor.encrypt(chunk))
+                del chunk
+            return cryptdata
+        d.addCallback(_got)
+        return d
+
+    def get_plaintext_segment_hashtree_nodes(self, num_segments):
+        if len(self._plaintext_segment_hashes) < num_segments:
+            # close out the last one
+            assert len(self._plaintext_segment_hashes) == num_segments-1
+            p, segment_left = self._get_segment_hasher()
+            self._plaintext_segment_hashes.append(p.digest())
+            del self._plaintext_segment_hasher
+        assert len(self._plaintext_segment_hashes) == num_segments
+        ht = hashtree.HashTree(self._plaintext_segment_hashes)
+        return defer.succeed(list(ht))
+
+    def get_plaintext_hash(self):
+        h = self._plaintext_hasher.digest()
+        return defer.succeed(h)
+
+
 class CHKUploader:
     peer_selector_class = Tahoe3PeerSelector
 
-    def __init__(self, client, uploadable, options={}):
+    def __init__(self, client, options={}):
         self._client = client
-        self._uploadable = IUploadable(uploadable)
         self._options = options
 
     def set_params(self, encoding_parameters):
         self._encoding_parameters = encoding_parameters
 
-        needed_shares, shares_of_happiness, total_shares = encoding_parameters
-        self.needed_shares = needed_shares
-        self.shares_of_happiness = shares_of_happiness
-        self.total_shares = total_shares
-
-    def start(self):
+    def start(self, uploadable):
         """Start uploading the file.
 
         This method returns a Deferred that will fire with the URI (a
         string)."""
 
-        log.msg("starting upload of %s" % self._uploadable)
+        uploadable = IUploadable(uploadable)
+        log.msg("starting upload of %s" % uploadable)
+
+        eu = EncryptAnUploadable(uploadable)
+        d = self.start_encrypted(eu)
+        def _uploaded(res):
+            d1 = uploadable.get_encryption_key()
+            d1.addCallback(lambda key: self._compute_uri(res, key))
+            return d1
+        d.addCallback(_uploaded)
+        return d
 
-        d = self._uploadable.get_size()
-        d.addCallback(self.setup_encoder)
-        d.addCallback(self._uploadable.get_encryption_key)
-        d.addCallback(self.setup_keys)
+    def start_encrypted(self, encrypted):
+        eu = IEncryptedUploadable(encrypted)
+
+        e = encode.Encoder(self._options)
+        e.set_params(self._encoding_parameters)
+        d = e.set_encrypted_uploadable(eu)
         d.addCallback(self.locate_all_shareholders)
-        d.addCallback(self.set_shareholders)
-        d.addCallback(lambda res: self._encoder.start())
-        d.addCallback(self._compute_uri)
+        d.addCallback(self.set_shareholders, e)
+        d.addCallback(lambda res: e.start())
+        # this fires with the uri_extension_hash and other data
         return d
 
-    def setup_encoder(self, size):
-        self._size = size
-        self._encoder = encode.Encoder(self._options)
-        self._encoder.set_size(size)
-        self._encoder.set_params(self._encoding_parameters)
-        self._encoder.set_uploadable(self._uploadable)
-        self._encoder.setup()
-        return self._encoder.get_serialized_params()
-
-    def setup_keys(self, key):
-        assert isinstance(key, str)
-        assert len(key) == 16  # AES-128
-        self._encryption_key = key
-        self._encoder.set_encryption_key(key)
-        storage_index = hashutil.storage_index_chk_hash(key)
-        assert isinstance(storage_index, str)
-        # There's no point to having the SI be longer than the key, so we
-        # specify that it is truncated to the same 128 bits as the AES key.
-        assert len(storage_index) == 16  # SHA-256 truncated to 128b
-        self._storage_index = storage_index
-        log.msg(" upload storage_index is [%s]" % (idlib.b2a(storage_index,)))
-
-
-    def locate_all_shareholders(self, ignored=None):
+    def locate_all_shareholders(self, encoder):
         peer_selector = self.peer_selector_class()
-        share_size = self._encoder.get_share_size()
-        block_size = self._encoder.get_block_size()
-        num_segments = self._encoder.get_num_segments()
+
+        storage_index = encoder.get_param("storage_index")
+        share_size = encoder.get_param("share_size")
+        block_size = encoder.get_param("block_size")
+        num_segments = encoder.get_param("num_segments")
+        k,desired,n = encoder.get_param("share_counts")
+
         gs = peer_selector.get_shareholders
-        d = gs(self._client,
-               self._storage_index, share_size, block_size,
-               num_segments, self.total_shares, self.shares_of_happiness)
+        d = gs(self._client, storage_index, share_size, block_size,
+               num_segments, n, desired)
         return d
 
-    def set_shareholders(self, used_peers):
+    def set_shareholders(self, used_peers, encoder):
         """
         @param used_peers: a sequence of PeerTracker objects
         """
@@ -317,16 +415,17 @@ class CHKUploader:
         for peer in used_peers:
             buckets.update(peer.buckets)
         assert len(buckets) == sum([len(peer.buckets) for peer in used_peers])
-        self._encoder.set_shareholders(buckets)
+        encoder.set_shareholders(buckets)
 
-    def _compute_uri(self, uri_extension_hash):
-        u = uri.CHKFileURI(key=self._encryption_key,
+    def _compute_uri(self, (uri_extension_hash,
+                            needed_shares, total_shares, size),
+                     key):
+        u = uri.CHKFileURI(key=key,
                            uri_extension_hash=uri_extension_hash,
-                           needed_shares=self.needed_shares,
-                           total_shares=self.total_shares,
-                           size=self._size,
+                           needed_shares=needed_shares,
+                           total_shares=total_shares,
+                           size=size,
                            )
-        assert u.storage_index == self._storage_index
         return u.to_string()
 
 def read_this_many_bytes(uploadable, size, prepend_data=[]):
@@ -348,17 +447,17 @@ def read_this_many_bytes(uploadable, size, prepend_data=[]):
 
 class LiteralUploader:
 
-    def __init__(self, client, uploadable, options={}):
+    def __init__(self, client, options={}):
         self._client = client
-        self._uploadable = IUploadable(uploadable)
         self._options = options
 
     def set_params(self, encoding_parameters):
         pass
 
-    def start(self):
-        d = self._uploadable.get_size()
-        d.addCallback(lambda size: read_this_many_bytes(self._uploadable, size))
+    def start(self, uploadable):
+        uploadable = IUploadable(uploadable)
+        d = uploadable.get_size()
+        d.addCallback(lambda size: read_this_many_bytes(uploadable, size))
         d.addCallback(lambda data: uri.LiteralFileURI("".join(data)))
         d.addCallback(lambda u: u.to_string())
         return d
@@ -370,25 +469,40 @@ class LiteralUploader:
 class ConvergentUploadMixin:
     # to use this, the class it is mixed in to must have a seekable
     # filehandle named self._filehandle
-
-    def get_encryption_key(self, encoding_parameters):
-        f = self._filehandle
-        enckey_hasher = hashutil.key_hasher()
-        #enckey_hasher.update(encoding_parameters) # TODO
-        f.seek(0)
-        BLOCKSIZE = 64*1024
-        while True:
-            data = f.read(BLOCKSIZE)
-            if not data:
-                break
-            enckey_hasher.update(data)
-        enckey = enckey_hasher.digest()[:16]
-        f.seek(0)
-        return defer.succeed(enckey)
+    _params = None
+    _key = None
+
+    def set_serialized_encoding_parameters(self, params):
+        self._params = params
+        # ignored for now
+
+    def get_encryption_key(self):
+        if self._key is None:
+            f = self._filehandle
+            enckey_hasher = hashutil.key_hasher()
+            #enckey_hasher.update(self._params) # TODO
+            f.seek(0)
+            BLOCKSIZE = 64*1024
+            while True:
+                data = f.read(BLOCKSIZE)
+                if not data:
+                    break
+                enckey_hasher.update(data)
+            f.seek(0)
+            self._key = enckey_hasher.digest()[:16]
+
+        return defer.succeed(self._key)
 
 class NonConvergentUploadMixin:
+    _key = None
+
+    def set_serialized_encoding_parameters(self, params):
+        pass
+
-    def get_encryption_key(self, encoding_parameters):
+    def get_encryption_key(self):
-        return defer.succeed(os.urandom(16))
+        if self._key is None:
+            self._key = os.urandom(16)
+        return defer.succeed(self._key)
 
 
 class FileHandle(ConvergentUploadMixin):
@@ -446,10 +560,10 @@ class Uploader(service.MultiService):
             uploader_class = self.uploader_class
             if size <= self.URI_LIT_SIZE_THRESHOLD:
                 uploader_class = LiteralUploader
-            uploader = uploader_class(self.parent, uploadable, options)
+            uploader = uploader_class(self.parent, options)
             uploader.set_params(self.parent.get_encoding_parameters()
                                 or self.DEFAULT_ENCODING_PARAMETERS)
-            return uploader.start()
+            return uploader.start(uploadable)
         d.addCallback(_got_size)
         def _done(res):
             uploadable.close()
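
The subtlest part of the new EncryptAnUploadable is _update_segment_hash():
read_encrypted() sees chunks of arbitrary size, but each plaintext segment
hash must cover exactly segment_size bytes, so chunks that straddle a
segment boundary get split across two hashers. A self-contained sketch of
the same loop, with hashlib.sha256 standing in for the tagged hashers in
allmydata.util.hashutil:

import hashlib

class SegmentHasher:
    """Accumulate arbitrary-sized chunks, emitting one digest per
    segment_size bytes (sketch of _update_segment_hash)."""
    def __init__(self, segment_size):
        self.segment_size = segment_size
        self.hasher = None
        self.hashed_bytes = 0
        self.digests = []

    def update(self, chunk):
        offset = 0
        while offset < len(chunk):
            if self.hasher is None:          # first byte of a new segment
                self.hasher = hashlib.sha256()
                self.hashed_bytes = 0
            left = self.segment_size - self.hashed_bytes
            this_segment = min(len(chunk) - offset, left)
            self.hasher.update(chunk[offset:offset+this_segment])
            self.hashed_bytes += this_segment
            if self.hashed_bytes == self.segment_size:
                self.digests.append(self.hasher.digest())  # segment full
                self.hasher = None
            offset += this_segment

    def close(self):
        if self.hasher is not None:          # short trailing segment
            self.digests.append(self.hasher.digest())
            self.hasher = None

sh = SegmentHasher(segment_size=5)
for chunk in [b"abc", b"defgh", b"ij"]:      # chunks straddle the boundary
    sh.update(chunk)
sh.close()
assert sh.digests == [hashlib.sha256(b"abcde").digest(),
                      hashlib.sha256(b"fghij").digest()]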