]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
upload: refactor to enable streaming upload. not all tests pass yet
authorBrian Warner <warner@allmydata.com>
Fri, 20 Jul 2007 01:21:44 +0000 (18:21 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 20 Jul 2007 01:21:44 +0000 (18:21 -0700)
src/allmydata/encode.py
src/allmydata/interfaces.py
src/allmydata/test/test_web.py
src/allmydata/upload.py
src/allmydata/util/hashutil.py

index ce470c95354ff858c1f829113fa568fdcf6b7f59..3290aec7b2d8e309f68982767e3ea1b2d1f52d2f 100644 (file)
@@ -88,20 +88,19 @@ class Encoder(object):
         self.TOTAL_SHARES = n
         self.uri_extension_data = {}
 
+    def set_size(self, size):
+        self.file_size = size
+
     def set_params(self, encoding_parameters):
         k,d,n = encoding_parameters
         self.NEEDED_SHARES = k
         self.SHARES_OF_HAPPINESS = d
         self.TOTAL_SHARES = n
 
-    def setup(self, infile, encryption_key):
-        self.infile = infile
-        assert isinstance(encryption_key, str)
-        assert len(encryption_key) == 16 # AES-128
-        self.key = encryption_key
-        infile.seek(0, 2)
-        self.file_size = infile.tell()
-        infile.seek(0, 0)
+    def set_uploadable(self, uploadable):
+        self._uploadable = uploadable
+
+    def setup(self):
 
         self.num_shares = self.TOTAL_SHARES
         self.required_shares = self.NEEDED_SHARES
@@ -111,10 +110,13 @@ class Encoder(object):
         # this must be a multiple of self.required_shares
         self.segment_size = mathutil.next_multiple(self.segment_size,
                                                    self.required_shares)
-        self.setup_codec()
+        self._setup_codec()
 
-    def setup_codec(self):
+    def _setup_codec(self):
         assert self.segment_size % self.required_shares == 0
+        self.num_segments = mathutil.div_ceil(self.file_size,
+                                              self.segment_size)
+
         self._codec = CRSEncoder()
         self._codec.set_params(self.segment_size,
                                self.required_shares, self.num_shares)
@@ -125,8 +127,9 @@ class Encoder(object):
 
         data['size'] = self.file_size
         data['segment_size'] = self.segment_size
-        data['num_segments'] = mathutil.div_ceil(self.file_size,
-                                                 self.segment_size)
+        self.share_size = mathutil.div_ceil(self.file_size,
+                                            self.required_shares)
+        data['num_segments'] = self.num_segments
         data['needed_shares'] = self.required_shares
         data['total_shares'] = self.num_shares
 
@@ -147,8 +150,13 @@ class Encoder(object):
                                     self.required_shares, self.num_shares)
         data['tail_codec_params'] = self._tail_codec.get_serialized_params()
 
-    def set_uri_extension_data(self, uri_extension_data):
-        self.uri_extension_data.update(uri_extension_data)
+    def get_serialized_params(self):
+        return self._codec.get_serialized_params()
+
+    def set_encryption_key(self, key):
+        assert isinstance(key, str)
+        assert len(key) == 16 # AES-128
+        self.key = key
 
     def get_share_size(self):
         share_size = mathutil.div_ceil(self.file_size, self.required_shares)
@@ -158,6 +166,8 @@ class Encoder(object):
         return 0
     def get_block_size(self):
         return self._codec.get_block_size()
+    def get_num_segments(self):
+        return self.num_segments
 
     def set_shareholders(self, landlords):
         assert isinstance(landlords, dict)
@@ -167,14 +177,11 @@ class Encoder(object):
 
     def start(self):
         #paddedsize = self._size + mathutil.pad_size(self._size, self.needed_shares)
-        self.num_segments = mathutil.div_ceil(self.file_size,
-                                              self.segment_size)
-        self.share_size = mathutil.div_ceil(self.file_size,
-                                            self.required_shares)
+        self._plaintext_hasher = hashutil.plaintext_hasher()
         self._plaintext_hashes = []
+        self._crypttext_hasher = hashutil.crypttext_hasher()
         self._crypttext_hashes = []
         self.setup_encryption()
-        self.setup_codec() # TODO: duplicate call?
         d = defer.succeed(None)
 
         for l in self.landlords.values():
@@ -185,8 +192,13 @@ class Encoder(object):
             # captures the slot, not the value
             #d.addCallback(lambda res: self.do_segment(i))
             # use this form instead:
-            d.addCallback(lambda res, i=i: self.do_segment(i))
-        d.addCallback(lambda res: self.do_tail_segment(self.num_segments-1))
+            d.addCallback(lambda res, i=i: self._encode_segment(i))
+            d.addCallback(self._send_segment, i)
+        last_segnum = self.num_segments - 1
+        d.addCallback(lambda res: self._encode_tail_segment(last_segnum))
+        d.addCallback(self._send_segment, last_segnum)
+
+        d.addCallback(lambda res: self.finish_flat_hashes())
 
         d.addCallback(lambda res:
                       self.send_plaintext_hash_tree_to_all_shareholders())
@@ -195,6 +207,7 @@ class Encoder(object):
         d.addCallback(lambda res: self.send_all_subshare_hash_trees())
         d.addCallback(lambda res: self.send_all_share_hash_trees())
         d.addCallback(lambda res: self.send_uri_extension_to_all_shareholders())
+
         d.addCallback(lambda res: self.close_all_shareholders())
         d.addCallbacks(lambda res: self.done(), self.err)
         return d
@@ -209,9 +222,9 @@ class Encoder(object):
         # that we sent to that landlord.
         self.share_root_hashes = [None] * self.num_shares
 
-    def do_segment(self, segnum):
-        chunks = []
+    def _encode_segment(self, segnum):
         codec = self._codec
+
         # the ICodecEncoder API wants to receive a total of self.segment_size
         # bytes on each encode() call, broken up into a number of
         # identically-sized pieces. Due to the way the codec algorithm works,
@@ -228,8 +241,8 @@ class Encoder(object):
         # of additional shares which can be substituted if the primary ones
         # are unavailable
 
-        plaintext_hasher = hashutil.plaintext_segment_hasher()
-        crypttext_hasher = hashutil.crypttext_segment_hasher()
+        plaintext_segment_hasher = hashutil.plaintext_segment_hasher()
+        crypttext_segment_hasher = hashutil.crypttext_segment_hasher()
 
         # memory footprint: we only hold a tiny piece of the plaintext at any
         # given time. We build up a segment's worth of cryptttext, then hand
@@ -238,56 +251,92 @@ class Encoder(object):
         # 10MiB. Lowering max_segment_size to, say, 100KiB would drop the
         # footprint to 500KiB at the expense of more hash-tree overhead.
 
-        for i in range(self.required_shares):
-            input_piece = self.infile.read(input_piece_size)
-            # non-tail segments should be the full segment size
-            assert len(input_piece) == input_piece_size
-            plaintext_hasher.update(input_piece)
-            encrypted_piece = self.cryptor.encrypt(input_piece)
-            assert len(encrypted_piece) == len(input_piece)
-            crypttext_hasher.update(encrypted_piece)
-
-            chunks.append(encrypted_piece)
-
-        self._plaintext_hashes.append(plaintext_hasher.digest())
-        self._crypttext_hashes.append(crypttext_hasher.digest())
-
-        d = codec.encode(chunks) # during this call, we hit 5*segsize memory
-        del chunks
-        d.addCallback(self._encoded_segment, segnum)
+        d = self._gather_data(self.required_shares, input_piece_size,
+                              plaintext_segment_hasher,
+                              crypttext_segment_hasher)
+        def _done(chunks):
+            for c in chunks:
+                assert len(c) == input_piece_size
+            self._plaintext_hashes.append(plaintext_segment_hasher.digest())
+            self._crypttext_hashes.append(crypttext_segment_hasher.digest())
+            # during this call, we hit 5*segsize memory
+            return codec.encode(chunks)
+        d.addCallback(_done)
         return d
 
-    def do_tail_segment(self, segnum):
-        chunks = []
+    def _encode_tail_segment(self, segnum):
+
         codec = self._tail_codec
         input_piece_size = codec.get_block_size()
 
-        plaintext_hasher = hashutil.plaintext_segment_hasher()
-        crypttext_hasher = hashutil.crypttext_segment_hasher()
-
-        for i in range(self.required_shares):
-            input_piece = self.infile.read(input_piece_size)
-            plaintext_hasher.update(input_piece)
-            encrypted_piece = self.cryptor.encrypt(input_piece)
-            assert len(encrypted_piece) == len(input_piece)
-            crypttext_hasher.update(encrypted_piece)
-
-            if len(encrypted_piece) < input_piece_size:
-                # padding
-                pad_size = (input_piece_size - len(encrypted_piece))
-                encrypted_piece += ('\x00' * pad_size)
-
-            chunks.append(encrypted_piece)
-
-        self._plaintext_hashes.append(plaintext_hasher.digest())
-        self._crypttext_hashes.append(crypttext_hasher.digest())
+        plaintext_segment_hasher = hashutil.plaintext_segment_hasher()
+        crypttext_segment_hasher = hashutil.crypttext_segment_hasher()
+
+        d = self._gather_data(self.required_shares, input_piece_size,
+                              plaintext_segment_hasher,
+                              crypttext_segment_hasher,
+                              allow_short=True)
+        def _done(chunks):
+            for c in chunks:
+                # a short trailing chunk will have been padded by
+                # _gather_data
+                assert len(c) == input_piece_size
+            self._plaintext_hashes.append(plaintext_segment_hasher.digest())
+            self._crypttext_hashes.append(crypttext_segment_hasher.digest())
+            return codec.encode(chunks)
+        d.addCallback(_done)
+        return d
 
-        d = codec.encode(chunks)
-        del chunks
-        d.addCallback(self._encoded_segment, segnum)
+    def _gather_data(self, num_chunks, input_chunk_size,
+                     plaintext_segment_hasher, crypttext_segment_hasher,
+                     allow_short=False,
+                     previous_chunks=[]):
+        """Return a Deferred that will fire when the required number of
+        chunks have been read (and hashed and encrypted). The Deferred fires
+        with the combination of any 'previous_chunks' and the new chunks
+        which were gathered."""
+
+        if not num_chunks:
+            return defer.succeed(previous_chunks)
+
+        d = self._uploadable.read(input_chunk_size)
+        def _got(data):
+            encrypted_pieces = []
+            length = 0
+            # we use data.pop(0) instead of 'for input_piece in data' to save
+            # memory: each piece is destroyed as soon as we're done with it.
+            while data:
+                input_piece = data.pop(0)
+                length += len(input_piece)
+                plaintext_segment_hasher.update(input_piece)
+                self._plaintext_hasher.update(input_piece)
+                encrypted_piece = self.cryptor.encrypt(input_piece)
+                assert len(encrypted_piece) == len(input_piece)
+                crypttext_segment_hasher.update(encrypted_piece)
+                self._crypttext_hasher.update(encrypted_piece)
+                encrypted_pieces.append(encrypted_piece)
+
+            if allow_short:
+                if length < input_chunk_size:
+                    # padding
+                    pad_size = input_chunk_size - length
+                    encrypted_pieces.append('\x00' * pad_size)
+            else:
+                # non-tail segments should be the full segment size
+                assert length == input_chunk_size
+
+            encrypted_piece = "".join(encrypted_pieces)
+            return previous_chunks + [encrypted_piece]
+
+        d.addCallback(_got)
+        d.addCallback(lambda chunks:
+                      self._gather_data(num_chunks-1, input_chunk_size,
+                                        plaintext_segment_hasher,
+                                        crypttext_segment_hasher,
+                                        allow_short, chunks))
         return d
 
-    def _encoded_segment(self, (shares, shareids), segnum):
+    def _send_segment(self, (shares, shareids), segnum):
         # To generate the URI, we must generate the roothash, so we must
         # generate all shares, even if we aren't actually giving them to
         # anybody. This means that the set of shares we create will be equal
@@ -354,6 +403,12 @@ class Encoder(object):
             d0.addErrback(_eatNotEnoughPeersError)
         return d
 
+    def finish_flat_hashes(self):
+        plaintext_hash = self._plaintext_hasher.digest()
+        crypttext_hash = self._crypttext_hasher.digest()
+        self.uri_extension_data["plaintext_hash"] = plaintext_hash
+        self.uri_extension_data["crypttext_hash"] = crypttext_hash
+
     def send_plaintext_hash_tree_to_all_shareholders(self):
         log.msg("%s sending plaintext hash tree" % self)
         t = HashTree(self._plaintext_hashes)
@@ -445,6 +500,10 @@ class Encoder(object):
 
     def send_uri_extension_to_all_shareholders(self):
         log.msg("%s: sending uri_extension" % self)
+        for k in ('crypttext_root_hash', 'crypttext_hash',
+                  'plaintext_root_hash', 'plaintext_hash',
+                  ):
+            assert k in self.uri_extension_data
         uri_extension = uri.pack_extension(self.uri_extension_data)
         self.uri_extension_hash = hashutil.uri_extension_hash(uri_extension)
         dl = []
index 2b4f4829f06d035e39c108aaeabc482ba8cf325c..1d4d246871652caf90ac21e003e4d8ca064f0b28 100644 (file)
@@ -679,14 +679,54 @@ class IDownloader(Interface):
         when the download is finished, or errbacks if something went wrong."""
 
 class IUploadable(Interface):
-    def get_filehandle():
-        """Return a filehandle from which the data to be uploaded can be
-        read. It must implement .read, .seek, and .tell (since the latter two
-        are used to determine the length of the data)."""
-    def close_filehandle(f):
-        """The upload is finished. This provides the same filehandle as was
-        returned by get_filehandle. This is an appropriate place to close the
-        filehandle."""
+    def get_size():
+        """Return a Deferred that will fire with the length of the data to be
+        uploaded, in bytes. This will be called before the data is actually
+        used, to compute encoding parameters.
+        """
+
+    def get_encryption_key(encoding_parameters):
+        """Return a Deferred that fires with a 16-byte AES key. This key will
+        be used to encrypt the data. The key will also be hashed to derive
+        the StorageIndex. 'encoding_parameters' is a string which indicates
+        how the data will be encoded (codec name, blocksize, number of
+        shares): Uploadables may wish to use these parameters while computing
+        the encryption key.
+
+        Uploadables which want to achieve convergence should hash their file
+        contents and the encoding_parameters to form the key (which of course
+        requires a full pass over the data). Uploadables can use the
+        upload.ConvergentUploadMixin class to achieve this automatically.
+
+        Uploadables which do not care about convergence (or do not wish to
+        make multiple passes over the data) can simply return a
+        strongly-random 16 byte string.
+        """
+
+    def read(length):
+        """Return a Deferred that fires with a list of strings (perhaps with
+        only a single element) which, when concatenated together, contain the
+        next 'length' bytes of data. If EOF is near, this may provide fewer
+        than 'length' bytes. The total number of bytes provided by read()
+        before it signals EOF must equal the size provided by get_size().
+
+        If the data must be acquired through multiple internal read
+        operations, returning a list instead of a single string may help to
+        reduce string copies.
+
+        'length' will typically be equal to (min(get_size(),1MB)/req_shares),
+        so a 10kB file means length=3kB, 100kB file means length=30kB,
+        and >=1MB file means length=300kB.
+
+        This method provides for a single full pass through the data. Later
+        use cases may desire multiple passes or access to only parts of the
+        data (such as a mutable file making small edits-in-place). This API
+        will be expanded once those use cases are better understood.
+        """
+
+    def close():
+        """The upload is finished, and whatever filehandle was in use may be
+        closed."""
 
 class IUploader(Interface):
     def upload(uploadable):
index b6182949a6844cd679a35479176fd711c88b971b..0b5046ff99ae4697fdcb172732cd2b6c4dea5469 100644 (file)
@@ -55,12 +55,15 @@ class MyUploader(service.Service):
         self.files = files
 
     def upload(self, uploadable):
-        f = uploadable.get_filehandle()
-        data = f.read()
-        uri = str(uri_counter.next())
-        self.files[uri] = data
-        uploadable.close_filehandle(f)
-        return defer.succeed(uri)
+        d = uploadable.get_size()
+        d.addCallback(lambda size: uploadable.read(size))
+        d.addCallback(lambda data: "".join(data))
+        def _got_data(data):
+            uri = str(uri_counter.next())
+            self.files[uri] = data
+            uploadable.close()
+        d.addCallback(_got_data)
+        return d
 
 class MyDirectoryNode(dirnode.MutableDirectoryNode):
 
@@ -94,15 +97,18 @@ class MyDirectoryNode(dirnode.MutableDirectoryNode):
         return defer.succeed(None)
 
     def add_file(self, name, uploadable):
-        f = uploadable.get_filehandle()
-        data = f.read()
-        uri = str(uri_counter.next())
-        self._my_files[uri] = data
-        self._my_nodes[uri] = MyFileNode(uri, self._my_client)
-        uploadable.close_filehandle(f)
-
-        self.children[name] = uri
-        return defer.succeed(self._my_nodes[uri])
+        d = uploadable.get_size()
+        d.addCallback(lambda size: uploadable.read(size))
+        d.addCallback(lambda data: "".join(data))
+        def _got_data(data):
+            uri = str(uri_counter.next())
+            self._my_files[uri] = data
+            self._my_nodes[uri] = MyFileNode(uri, self._my_client)
+            self.children[name] = uri
+            uploadable.close()
+            return self._my_nodes[uri]
+        d.addCallback(_got_data)
+        return d
 
     def delete(self, name):
         def _try():
index bebb46496eb5b16d788872aa355265217f5453bb..af37309d80f3cf3b5daad74440261cbab4c3366c 100644 (file)
@@ -1,3 +1,5 @@
+
+import os
 from zope.interface import implements
 from twisted.python import log
 from twisted.internet import defer
@@ -8,7 +10,6 @@ from allmydata.util import idlib, hashutil
 from allmydata import encode, storage, hashtree
 from allmydata.uri import pack_uri, pack_lit
 from allmydata.interfaces import IUploadable, IUploader
-from allmydata.Crypto.Cipher import AES
 
 from cStringIO import StringIO
 import collections, random
@@ -83,94 +84,39 @@ class PeerTracker:
         self.buckets.update(b)
         return (alreadygot, set(b.keys()))
 
-class FileUploader:
-
-    def __init__(self, client, options={}):
-        self._client = client
-        self._options = options
-
-    def set_params(self, encoding_parameters):
-        self._encoding_parameters = encoding_parameters
-
-        needed_shares, shares_of_happiness, total_shares = encoding_parameters
-        self.needed_shares = needed_shares
-        self.shares_of_happiness = shares_of_happiness
-        self.total_shares = total_shares
-
-    def set_filehandle(self, filehandle):
-        self._filehandle = filehandle
-        filehandle.seek(0, 2)
-        self._size = filehandle.tell()
-        filehandle.seek(0)
-
-    def set_id_strings(self, crypttext_hash, plaintext_hash):
-        assert isinstance(crypttext_hash, str)
-        assert len(crypttext_hash) == 32
-        self._crypttext_hash = crypttext_hash
-        assert isinstance(plaintext_hash, str)
-        assert len(plaintext_hash) == 32
-        self._plaintext_hash = plaintext_hash
-
-    def set_encryption_key(self, key):
-        assert isinstance(key, str)
-        assert len(key) == 16  # AES-128
-        self._encryption_key = key
-
-    def start(self):
-        """Start uploading the file.
-
-        The source of the data to be uploaded must have been set before this
-        point by calling set_filehandle().
-
-        This method returns a Deferred that will fire with the URI (a
-        string)."""
-
-        log.msg("starting upload [%s]" % (idlib.b2a(self._crypttext_hash),))
-        assert self.needed_shares
-
-        # create the encoder, so we can know how large the shares will be
-        share_size, block_size = self.setup_encoder()
-
-        d = self._locate_all_shareholders(share_size, block_size)
-        d.addCallback(self._send_shares)
-        d.addCallback(self._compute_uri)
-        return d
-
-    def setup_encoder(self):
-        self._encoder = encode.Encoder(self._options)
-        self._encoder.set_params(self._encoding_parameters)
-        self._encoder.setup(self._filehandle, self._encryption_key)
-        share_size = self._encoder.get_share_size()
-        block_size = self._encoder.get_block_size()
-        return share_size, block_size
+class Tahoe3PeerSelector:
 
-    def _locate_all_shareholders(self, share_size, block_size):
+    def get_shareholders(self, client,
+                         storage_index, share_size, block_size,
+                         num_segments, total_shares, shares_of_happiness):
         """
         @return: a set of PeerTracker instances that have agreed to hold some
             shares for us
         """
+
+        self.total_shares = total_shares
+        self.shares_of_happiness = shares_of_happiness
+
         # we are responsible for locating the shareholders. self._encoder is
         # responsible for handling the data and sending out the shares.
-        peers = self._client.get_permuted_peers(self._crypttext_hash)
+        peers = client.get_permuted_peers(storage_index)
         assert peers
 
-        # TODO: eek, don't pull this from here, find a better way. gross.
-        num_segments = self._encoder.uri_extension_data['num_segments']
-        ht = hashtree.IncompleteHashTree(self.total_shares)
         # this needed_hashes computation should mirror
         # Encoder.send_all_share_hash_trees. We use an IncompleteHashTree
         # (instead of a HashTree) because we don't require actual hashing
         # just to count the levels.
+        ht = hashtree.IncompleteHashTree(total_shares)
         num_share_hashes = len(ht.needed_hashes(0, include_leaf=True))
 
         trackers = [ PeerTracker(peerid, permutedid, conn,
                                  share_size, block_size,
                                  num_segments, num_share_hashes,
-                                 self._crypttext_hash)
+                                 storage_index)
                      for permutedid, peerid, conn in peers ]
         self.usable_peers = set(trackers) # this set shrinks over time
         self.used_peers = set() # while this set grows
-        self.unallocated_sharenums = set(range(self.total_shares)) # this one shrinks
+        self.unallocated_sharenums = set(range(total_shares)) # this one shrinks
 
         return self._locate_more_shareholders()
 
@@ -181,18 +127,23 @@ class FileUploader:
 
     def _located_some_shareholders(self, res):
         log.msg("_located_some_shareholders")
-        log.msg(" still need homes for %d shares, still have %d usable peers" % (len(self.unallocated_sharenums), len(self.usable_peers)))
+        log.msg(" still need homes for %d shares, still have %d usable peers"
+                % (len(self.unallocated_sharenums), len(self.usable_peers)))
         if not self.unallocated_sharenums:
             # Finished allocating places for all shares.
-            log.msg("%s._locate_all_shareholders() Finished allocating places for all shares." % self)
+            log.msg("%s._locate_all_shareholders() "
+                    "Finished allocating places for all shares." % self)
             log.msg("used_peers is %s" % (self.used_peers,))
             return self.used_peers
         if not self.usable_peers:
             # Ran out of peers who have space.
-            log.msg("%s._locate_all_shareholders() Ran out of peers who have space." % self)
-            if len(self.unallocated_sharenums) < (self.total_shares - self.shares_of_happiness):
+            log.msg("%s._locate_all_shareholders() "
+                    "Ran out of peers who have space." % self)
+            margin = self.total_shares - self.shares_of_happiness
+            if len(self.unallocated_sharenums) < margin:
                 # But we allocated places for enough shares.
-                log.msg("%s._locate_all_shareholders() But we allocated places for enough shares.")
+                log.msg("%s._locate_all_shareholders() "
+                        "But we allocated places for enough shares.")
                 return self.used_peers
             raise encode.NotEnoughPeersError
         # we need to keep trying
@@ -201,7 +152,10 @@ class FileUploader:
     def _create_ring_of_things(self):
         PEER = 1 # must sort later than SHARE, for consistency with download
         SHARE = 0
-        ring_of_things = [] # a list of (position_in_ring, whatami, x) where whatami is SHARE if x is a sharenum or else PEER if x is a PeerTracker instance
+        # ring_of_things is a list of (position_in_ring, whatami, x) where
+        # whatami is SHARE if x is a sharenum or else PEER if x is a
+        # PeerTracker instance
+        ring_of_things = []
         ring_of_things.extend([ (peer.permutedid, PEER, peer,)
                                 for peer in self.usable_peers ])
         shares = [ (i * 2**160 / self.total_shares, SHARE, i)
@@ -258,7 +212,11 @@ class FileUploader:
         # sets into sets.Set on us, even when we're using 2.4
         alreadygot = set(alreadygot)
         allocated = set(allocated)
-        #log.msg("%s._got_response(%s, %s, %s): self.unallocated_sharenums: %s, unhandled: %s" % (self, (alreadygot, allocated), peer, shares_we_requested, self.unallocated_sharenums, shares_we_requested - alreadygot - allocated))
+        #log.msg("%s._got_response(%s, %s, %s): "
+        #        "self.unallocated_sharenums: %s, unhandled: %s"
+        #        % (self, (alreadygot, allocated), peer, shares_we_requested,
+        #           self.unallocated_sharenums,
+        #           shares_we_requested - alreadygot - allocated))
         self.unallocated_sharenums -= alreadygot
         self.unallocated_sharenums -= allocated
 
@@ -266,15 +224,90 @@ class FileUploader:
             self.used_peers.add(peer)
 
         if shares_we_requested - alreadygot - allocated:
-            #log.msg("%s._got_response(%s, %s, %s): self.unallocated_sharenums: %s, unhandled: %s HE'S FULL" % (self, (alreadygot, allocated), peer, shares_we_requested, self.unallocated_sharenums, shares_we_requested - alreadygot - allocated))
             # Then he didn't accept some of the shares, so he's full.
+
+            #log.msg("%s._got_response(%s, %s, %s): "
+            #        "self.unallocated_sharenums: %s, unhandled: %s HE'S FULL"
+            #        % (self,
+            #           (alreadygot, allocated), peer, shares_we_requested,
+            #           self.unallocated_sharenums,
+            #           shares_we_requested - alreadygot - allocated))
             self.usable_peers.remove(peer)
 
     def _got_error(self, f, peer):
         log.msg("%s._got_error(%s, %s)" % (self, f, peer,))
         self.usable_peers.remove(peer)
 
-    def _send_shares(self, used_peers):
+
+class CHKUploader:
+    peer_selector_class = Tahoe3PeerSelector
+
+    def __init__(self, client, uploadable, options={}):
+        self._client = client
+        self._uploadable = IUploadable(uploadable)
+        self._options = options
+
+    def set_params(self, encoding_parameters):
+        self._encoding_parameters = encoding_parameters
+
+        needed_shares, shares_of_happiness, total_shares = encoding_parameters
+        self.needed_shares = needed_shares
+        self.shares_of_happiness = shares_of_happiness
+        self.total_shares = total_shares
+
+    def start(self):
+        """Start uploading the file.
+
+        This method returns a Deferred that will fire with the URI (a
+        string)."""
+
+        log.msg("starting upload of %s" % self._uploadable)
+
+        d = self._uploadable.get_size()
+        d.addCallback(self.setup_encoder)
+        d.addCallback(self._uploadable.get_encryption_key)
+        d.addCallback(self.setup_keys)
+        d.addCallback(self.locate_all_shareholders)
+        d.addCallback(self.set_shareholders)
+        d.addCallback(lambda res: self._encoder.start())
+        d.addCallback(self._compute_uri)
+        return d
+
+    def setup_encoder(self, size):
+        self._size = size
+        self._encoder = encode.Encoder(self._options)
+        self._encoder.set_size(size)
+        self._encoder.set_params(self._encoding_parameters)
+        self._encoder.set_uploadable(self._uploadable)
+        self._encoder.setup()
+        return self._encoder.get_serialized_params()
+
+    def setup_keys(self, key):
+        assert isinstance(key, str)
+        assert len(key) == 16  # AES-128
+        self._encryption_key = key
+        self._encoder.set_encryption_key(key)
+        storage_index = hashutil.storage_index_chk_hash(key)
+        assert isinstance(storage_index, str)
+        # TODO: is there any point to having the SI be longer than the key?
+        # There's certainly no extra entropy to be had..
+        assert len(storage_index) == 32  # SHA-256
+        self._storage_index = storage_index
+        log.msg(" upload SI is [%s]" % (idlib.b2a(storage_index,)))
+
+
+    def locate_all_shareholders(self, ignored=None):
+        peer_selector = self.peer_selector_class()
+        share_size = self._encoder.get_share_size()
+        block_size = self._encoder.get_block_size()
+        num_segments = self._encoder.get_num_segments()
+        gs = peer_selector.get_shareholders
+        d = gs(self._client,
+               self._storage_index, share_size, block_size,
+               num_segments, self.total_shares, self.shares_of_happiness)
+        return d
+
+    def set_shareholders(self, used_peers):
         """
         @param used_peers: a sequence of PeerTracker objects
         """
@@ -287,14 +320,8 @@ class FileUploader:
         assert len(buckets) == sum([len(peer.buckets) for peer in used_peers])
         self._encoder.set_shareholders(buckets)
 
-        uri_extension_data = {}
-        uri_extension_data['crypttext_hash'] = self._crypttext_hash
-        uri_extension_data['plaintext_hash'] = self._plaintext_hash
-        self._encoder.set_uri_extension_data(uri_extension_data)
-        return self._encoder.start()
-
     def _compute_uri(self, uri_extension_hash):
-        return pack_uri(storage_index=self._crypttext_hash,
+        return pack_uri(storage_index=self._storage_index,
                         key=self._encryption_key,
                         uri_extension_hash=uri_extension_hash,
                         needed_shares=self.needed_shares,
@@ -302,55 +329,101 @@ class FileUploader:
                         size=self._size,
                         )
 
+def read_this_many_bytes(uploadable, size, prepend_data=[]):
+    d = uploadable.read(size)
+    def _got(data):
+        assert isinstance(list)
+        bytes = sum([len(piece) for piece in data])
+        assert bytes > 0
+        assert bytes <= size
+        remaining = size - bytes
+        if remaining:
+            return read_this_many_bytes(uploadable, remaining,
+                                        prepend_data + data)
+        return prepend_data + data
+    d.addCallback(_got)
+    return d
+
 class LiteralUploader:
 
-    def __init__(self, client, options={}):
+    def __init__(self, client, uploadable, options={}):
         self._client = client
+        self._uploadable = IUploadable(uploadable)
         self._options = options
 
-    def set_filehandle(self, filehandle):
-        self._filehandle = filehandle
+    def set_params(self, encoding_parameters):
+        pass
 
     def start(self):
-        self._filehandle.seek(0)
-        data = self._filehandle.read()
-        return defer.succeed(pack_lit(data))
+        d = self._uploadable.get_size()
+        d.addCallback(lambda size: read_this_many_bytes(self._uploadable, size))
+        d.addCallback(lambda data: pack_lit("".join(data)))
+        return d
 
+    def close(self):
+        pass
 
-class FileName:
-    implements(IUploadable)
-    def __init__(self, filename):
-        self._filename = filename
-    def get_filehandle(self):
-        return open(self._filename, "rb")
-    def close_filehandle(self, f):
-        f.close()
 
-class Data:
-    implements(IUploadable)
-    def __init__(self, data):
-        self._data = data
-    def get_filehandle(self):
-        return StringIO(self._data)
-    def close_filehandle(self, f):
-        pass
+class ConvergentUploadMixin:
+    # to use this, the class it is mixed in to must have a seekable
+    # filehandle named self._filehandle
 
-class FileHandle:
+    def get_encryption_key(self, encoding_parameters):
+        f = self._filehandle
+        enckey_hasher = hashutil.key_hasher()
+        #enckey_hasher.update(encoding_parameters) # TODO
+        f.seek(0)
+        BLOCKSIZE = 64*1024
+        while True:
+            data = f.read(BLOCKSIZE)
+            if not data:
+                break
+            enckey_hasher.update(data)
+        enckey = enckey_hasher.digest()[:16]
+        f.seek(0)
+        return defer.succeed(enckey)
+
+class NonConvergentUploadMixin:
+    def get_encryption_key(self, encoding_parameters):
+        return defer.succeed(os.urandom(16))
+
+
+class FileHandle(ConvergentUploadMixin):
     implements(IUploadable)
+
     def __init__(self, filehandle):
         self._filehandle = filehandle
-    def get_filehandle(self):
-        return self._filehandle
-    def close_filehandle(self, f):
+
+    def get_size(self):
+        self._filehandle.seek(0,2)
+        size = self._filehandle.tell()
+        self._filehandle.seek(0)
+        return defer.succeed(size)
+
+    def read(self, length):
+        return defer.succeed([self._filehandle.read(length)])
+
+    def close(self):
         # the originator of the filehandle reserves the right to close it
         pass
 
+class FileName(FileHandle):
+    def __init__(self, filename):
+        FileHandle.__init__(self, open(filename, "rb"))
+    def close(self):
+        FileHandle.close(self)
+        self._filehandle.close()
+
+class Data(FileHandle):
+    def __init__(self, data):
+        FileHandle.__init__(self, StringIO(data))
+
 class Uploader(service.MultiService):
     """I am a service that allows file uploading.
     """
     implements(IUploader)
     name = "uploader"
-    uploader_class = FileUploader
+    uploader_class = CHKUploader
     URI_LIT_SIZE_THRESHOLD = 55
 
     DEFAULT_ENCODING_PARAMETERS = (25, 75, 100)
@@ -360,65 +433,23 @@ class Uploader(service.MultiService):
     # 'total' is the total number of shares created by encoding. If everybody
     # has room then this is is how many we will upload.
 
-    def compute_id_strings(self, f):
-        # return a list of (plaintext_hash, encryptionkey, crypttext_hash)
-        plaintext_hasher = hashutil.plaintext_hasher()
-        enckey_hasher = hashutil.key_hasher()
-        f.seek(0)
-        BLOCKSIZE = 64*1024
-        while True:
-            data = f.read(BLOCKSIZE)
-            if not data:
-                break
-            plaintext_hasher.update(data)
-            enckey_hasher.update(data)
-        plaintext_hash = plaintext_hasher.digest()
-        enckey = enckey_hasher.digest()
-
-        # now make a second pass to determine the crypttext_hash. It would be
-        # nice to make this involve fewer passes.
-        crypttext_hasher = hashutil.crypttext_hasher()
-        key = enckey[:16]
-        cryptor = AES.new(key=key, mode=AES.MODE_CTR,
-                          counterstart="\x00"*16)
-        f.seek(0)
-        while True:
-            data = f.read(BLOCKSIZE)
-            if not data:
-                break
-            crypttext_hasher.update(cryptor.encrypt(data))
-        crypttext_hash = crypttext_hasher.digest()
-
-        # and leave the file pointer at the beginning
-        f.seek(0)
-
-        return plaintext_hash, key, crypttext_hash
-
-    def upload(self, f, options={}):
+    def upload(self, uploadable, options={}):
         # this returns the URI
         assert self.parent
         assert self.running
-        f = IUploadable(f)
-        fh = f.get_filehandle()
-        fh.seek(0,2)
-        size = fh.tell()
-        fh.seek(0)
-        if size <= self.URI_LIT_SIZE_THRESHOLD:
-            u = LiteralUploader(self.parent, options)
-            u.set_filehandle(fh)
-        else:
-            u = self.uploader_class(self.parent, options)
-            u.set_filehandle(fh)
-            encoding_parameters = self.parent.get_encoding_parameters()
-            if not encoding_parameters:
-                encoding_parameters = self.DEFAULT_ENCODING_PARAMETERS
-            u.set_params(encoding_parameters)
-            plaintext_hash, key, crypttext_hash = self.compute_id_strings(fh)
-            u.set_encryption_key(key)
-            u.set_id_strings(crypttext_hash, plaintext_hash)
-        d = u.start()
+        uploadable = IUploadable(uploadable)
+        d = uploadable.get_size()
+        def _got_size(size):
+            uploader_class = self.uploader_class
+            if size <= self.URI_LIT_SIZE_THRESHOLD:
+                uploader_class = LiteralUploader
+            uploader = self.uploader_class(self.parent, uploadable, options)
+            uploader.set_params(self.parent.get_encoding_parameters()
+                                or self.DEFAULT_ENCODING_PARAMETERS)
+            return uploader.start()
+        d.addCallback(_got_size)
         def _done(res):
-            f.close_filehandle(fh)
+            uploadable.close()
             return res
         d.addBoth(_done)
         return d
index db8c937c378c462fd801ef63cfed63e6db7a83a5..62042c4b8e59bfccb9b209604bf3778376c6c75b 100644 (file)
@@ -22,6 +22,9 @@ def tagged_pair_hash(tag, val1, val2):
 def tagged_hasher(tag):
     return SHA256.new(netstring(tag))
 
+def storage_index_chk_hash(data):
+    return tagged_hash("allmydata_CHK_storage_index_v1", data)
+
 def block_hash(data):
     return tagged_hash("allmydata_encoded_subshare_v1", data)
 def block_hasher():