finish making the new encoder/decoder/upload/download work
authorBrian Warner <warner@allmydata.com>
Fri, 30 Mar 2007 23:50:50 +0000 (16:50 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 30 Mar 2007 23:50:50 +0000 (16:50 -0700)
src/allmydata/download.py
src/allmydata/encode_new.py
src/allmydata/interfaces.py
src/allmydata/storageserver.py
src/allmydata/test/test_encode.py
src/allmydata/test/test_storage.py [deleted file]
src/allmydata/test/test_upload.py
src/allmydata/upload.py
src/allmydata/uri.py

index bafce789151de177354b7e4ee89f9d3c1c1b16ad..eacc9eff3ae1fae6a1371b9baf162677e2b1d247 100644 (file)
@@ -5,8 +5,7 @@ from twisted.python import log
 from twisted.internet import defer
 from twisted.application import service
 
-from allmydata.util import idlib, bencode, mathutil
-from allmydata.util.deferredutil import DeferredListShouldSucceed
+from allmydata.util import idlib, mathutil
 from allmydata.util.assertutil import _assert
 from allmydata import codec
 from allmydata.Crypto.Cipher import AES
@@ -48,6 +47,7 @@ class Output:
     def finish(self):
         return self.downloadable.finish()
 
+
 class BlockDownloader:
     def __init__(self, bucket, blocknum, parent):
         self.bucket = bucket
@@ -131,18 +131,23 @@ class FileDownloader:
     def __init__(self, client, uri, downloadable):
         self._client = client
         self._downloadable = downloadable
-        (codec_name, codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size) = unpack_uri(uri)
+        (codec_name, codec_params, tail_codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size) = unpack_uri(uri)
         assert isinstance(verifierid, str)
         assert len(verifierid) == 20
         self._verifierid = verifierid
         self._roothash = roothash
-        self._decoder = codec.get_decoder_by_name(codec_name)
-        self._decoder.set_serialized_params(codec_params)
+
+        self._codec = codec.get_decoder_by_name(codec_name)
+        self._codec.set_serialized_params(codec_params)
+        self._tail_codec = codec.get_decoder_by_name(codec_name)
+        self._tail_codec.set_serialized_params(tail_codec_params)
+
+
         self._total_segments = mathutil.div_ceil(size, segment_size)
         self._current_segnum = 0
         self._segment_size = segment_size
         self._size = size
-        self._num_needed_shares = self._decoder.get_needed_shares()
+        self._num_needed_shares = self._codec.get_needed_shares()
 
         key = "\x00" * 16
         self._output = Output(downloadable, key)
@@ -173,14 +178,16 @@ class FileDownloader:
     def _get_all_shareholders(self):
         dl = []
         for (permutedpeerid, peerid, connection) in self._client.get_permuted_peers(self._verifierid):
-            d = connection.callRemote("get_buckets", self._verifierid)
+            d = connection.callRemote("get_service", "storageserver")
+            d.addCallback(lambda ss: ss.callRemote("get_buckets",
+                                                   self._verifierid))
             d.addCallbacks(self._got_response, self._got_error,
                            callbackArgs=(connection,))
             dl.append(d)
         return defer.DeferredList(dl)
 
     def _got_response(self, buckets, connection):
-        for sharenum, bucket in buckets:
+        for sharenum, bucket in buckets.iteritems():
             self._share_buckets.setdefault(sharenum, set()).add(bucket)
         
     def _got_error(self, f):
@@ -193,30 +200,34 @@ class FileDownloader:
         self.active_buckets = {}
         self._output.open()
 
-    def _download_all_segments(self):
-        d = self._download_segment(self._current_segnum)
+    def _download_all_segments(self, res):
+        d = defer.succeed(None)
+        for segnum in range(self._total_segments-1):
+            d.addCallback(self._download_segment, segnum)
+        d.addCallback(self._download_tail_segment, self._total_segments-1)
+        return d
+
+    def _download_segment(self, res, segnum):
+        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
+        d = segmentdler.start()
+        d.addCallback(lambda (shares, shareids):
+                      self._codec.decode(shares, shareids))
         def _done(res):
-            if self._current_segnum == self._total_segments:
-                return None
-            return self._download_segment(self._current_segnum)
+            for buf in res:
+                self._output.write(buf)
         d.addCallback(_done)
         return d
 
-    def _download_segment(self, segnum):
+    def _download_tail_segment(self, res, segnum):
         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
         d.addCallback(lambda (shares, shareids):
-                      self._decoder.decode(shares, shareids))
+                      self._tail_codec.decode(shares, shareids))
         def _done(res):
-            self._current_segnum += 1
-            if self._current_segnum == self._total_segments:
-                data = ''.join(res)
-                padsize = mathutil.pad_size(self._size, self._segment_size)
-                data = data[:-padsize]
-                self._output.write(data)
-            else:
-                for buf in res:
-                    self._output.write(buf)
+            # trim off any padding added by the upload side
+            data = ''.join(res)
+            tail_size = self._size % self._segment_size
+            self._output.write(data[:tail_size])
         d.addCallback(_done)
         return d
 
@@ -302,8 +313,7 @@ class Downloader(service.MultiService):
         t = IDownloadTarget(t)
         assert t.write
         assert t.close
-        dl = FileDownloader(self.parent, uri)
-        dl.set_download_target(t)
+        dl = FileDownloader(self.parent, uri, t)
         if self.debug:
             dl.debug = True
         d = dl.start()
index 5d0458d72a7c48aa350ac1c417ba066384b77db7..dc218f18f0f9a58486e41fbed16ec90e19a119a2 100644 (file)
@@ -5,6 +5,7 @@ from twisted.internet import defer
 from allmydata.chunk import HashTree, roundup_pow2
 from allmydata.Crypto.Cipher import AES
 from allmydata.util import mathutil, hashutil
+from allmydata.util.assertutil import _assert
 from allmydata.codec import CRSEncoder
 from allmydata.interfaces import IEncoder
 
@@ -88,12 +89,32 @@ class Encoder(object):
         self.required_shares = self.NEEDED_SHARES
 
         self.segment_size = min(2*MiB, self.file_size)
+        # this must be a multiple of self.required_shares
+        self.segment_size = mathutil.next_multiple(self.segment_size,
+                                                   self.required_shares)
         self.setup_codec()
 
     def setup_codec(self):
+        assert self.segment_size % self.required_shares == 0
         self._codec = CRSEncoder()
-        self._codec.set_params(self.segment_size, self.required_shares,
-                               self.num_shares)
+        self._codec.set_params(self.segment_size,
+                               self.required_shares, self.num_shares)
+
+        # the "tail" is the last segment. This segment may or may not be
+        # shorter than all other segments. We use the "tail codec" to handle
+        # it. If the tail is short, we use a different codec instance. In
+        # addition, the tail codec must be fed data which has been padded out
+        # to the right size.
+        self.tail_size = self.file_size % self.segment_size
+        if not self.tail_size:
+            self.tail_size = self.segment_size
+
+        # the tail codec is responsible for encoding tail_size bytes
+        padded_tail_size = mathutil.next_multiple(self.tail_size,
+                                                  self.required_shares)
+        self._tail_codec = CRSEncoder()
+        self._tail_codec.set_params(padded_tail_size,
+                                    self.required_shares, self.num_shares)
 
     def get_share_size(self):
         share_size = mathutil.div_ceil(self.file_size, self.required_shares)
@@ -105,6 +126,11 @@ class Encoder(object):
         return self._codec.get_block_size()
 
     def set_shareholders(self, landlords):
+        assert isinstance(landlords, dict)
+        for k in landlords:
+            # it would be nice to:
+            #assert RIBucketWriter.providedBy(landlords[k])
+            pass
         self.landlords = landlords.copy()
 
     def start(self):
@@ -116,8 +142,11 @@ class Encoder(object):
         self.setup_encryption()
         self.setup_codec()
         d = defer.succeed(None)
-        for i in range(self.num_segments):
+
+        for i in range(self.num_segments-1):
             d.addCallback(lambda res: self.do_segment(i))
+        d.addCallback(lambda res: self.do_tail_segment(self.num_segments-1))
+
         d.addCallback(lambda res: self.send_all_subshare_hash_trees())
         d.addCallback(lambda res: self.send_all_share_hash_trees())
         d.addCallback(lambda res: self.close_all_shareholders())
@@ -137,13 +166,14 @@ class Encoder(object):
 
     def do_segment(self, segnum):
         chunks = []
+        codec = self._codec
         # the ICodecEncoder API wants to receive a total of self.segment_size
         # bytes on each encode() call, broken up into a number of
         # identically-sized pieces. Due to the way the codec algorithm works,
         # these pieces need to be the same size as the share which the codec
         # will generate. Therefore we must feed it with input_piece_size that
         # equals the output share size.
-        input_piece_size = self._codec.get_block_size()
+        input_piece_size = codec.get_block_size()
 
         # as a result, the number of input pieces per encode() call will be
         # equal to the number of required shares with which the codec was
@@ -153,6 +183,21 @@ class Encoder(object):
         # of additional shares which can be substituted if the primary ones
         # are unavailable
 
+        for i in range(self.required_shares):
+            input_piece = self.infile.read(input_piece_size)
+            # non-tail segments should be the full segment size
+            assert len(input_piece) == input_piece_size
+            encrypted_piece = self.cryptor.encrypt(input_piece)
+            chunks.append(encrypted_piece)
+        d = codec.encode(chunks)
+        d.addCallback(self._encoded_segment, segnum)
+        return d
+
+    def do_tail_segment(self, segnum):
+        chunks = []
+        codec = self._tail_codec
+        input_piece_size = codec.get_block_size()
+
         for i in range(self.required_shares):
             input_piece = self.infile.read(input_piece_size)
             if len(input_piece) < input_piece_size:
@@ -160,20 +205,21 @@ class Encoder(object):
                 input_piece += ('\x00' * (input_piece_size - len(input_piece)))
             encrypted_piece = self.cryptor.encrypt(input_piece)
             chunks.append(encrypted_piece)
-        d = self._codec.encode(chunks)
-        d.addCallback(self._encoded_segment)
+        d = codec.encode(chunks)
+        d.addCallback(self._encoded_segment, segnum)
         return d
 
-    def _encoded_segment(self, (shares, shareids)):
+    def _encoded_segment(self, (shares, shareids), segnum):
+        _assert(set(shareids) == set(self.landlords.keys()),
+                shareids=shareids, landlords=self.landlords)
         dl = []
         for i in range(len(shares)):
             subshare = shares[i]
             shareid = shareids[i]
-            d = self.send_subshare(shareid, self.segment_num, subshare)
+            d = self.send_subshare(shareid, segnum, subshare)
             dl.append(d)
             subshare_hash = hashutil.tagged_hash("encoded subshare", subshare)
             self.subshare_hashes[shareid].append(subshare_hash)
-        self.segment_num += 1
         return defer.DeferredList(dl)
 
     def send_subshare(self, shareid, segment_num, subshare):
index cc235779f1e2d34491cf241f3336462fa6284c59..4cb69ba514b6a08cd903227427117fca6e964737 100644 (file)
@@ -9,7 +9,7 @@ Hash = StringConstraint(HASH_SIZE) # binary format 32-byte SHA256 hash
 Nodeid = StringConstraint(20) # binary format 20-byte SHA1 hash
 PBURL = StringConstraint(150)
 Verifierid = StringConstraint(20)
-URI = StringConstraint(100) # kind of arbitrary
+URI = StringConstraint(200) # kind of arbitrary
 ShareData = StringConstraint(100000)
 # these six are here because Foolscap does not yet support the kind of
 # restriction I really want to apply to these.
@@ -38,6 +38,9 @@ class RIClient(RemoteInterface):
 
 class RIBucketWriter(RemoteInterface):
     def put_block(segmentnum=int, data=ShareData):
+        """@param data: For most segments, this data will be 'blocksize'
+        bytes in length. The last segment might be shorter.
+        """
         return None
     
     def put_block_hashes(blockhashes=ListOf(Hash)):
@@ -68,6 +71,9 @@ class RIStorageServer(RemoteInterface):
 
 class RIBucketReader(RemoteInterface):
     def get_block(blocknum=int):
+        """Most blocks will be the same size. The last block might be shorter
+        than the others.
+        """
         return ShareData
     def get_block_hashes():
         return ListOf(Hash)
index bb9622d164c73c9a45aa6f4b38d6c8e52899c3b6..23d3ece6778425438bd5dfc6f11f08e6fb265d31 100644 (file)
@@ -29,6 +29,8 @@ class BucketWriter(Referenceable):
         self.finalhome = finalhome
         self.blocksize = blocksize
         self.closed = False
+        fileutil.make_dirs(incominghome)
+        fileutil.make_dirs(finalhome)
         self._write_file('blocksize', str(blocksize))
 
     def _write_file(self, fname, data):
@@ -51,7 +53,7 @@ class BucketWriter(Referenceable):
         precondition(not self.closed)
         self._write_file('sharehashree', bencode.bencode(sharehashes))
 
-    def close(self):
+    def remote_close(self):
         precondition(not self.closed)
         # TODO assert or check the completeness and consistency of the data that has been written
         fileutil.rename(self.incominghome, self.finalhome)
@@ -72,7 +74,7 @@ class BucketReader(Referenceable):
     def remote_get_block(self, blocknum):
         f = open(os.path.join(self.home, 'data'), 'rb')
         f.seek(self.blocksize * blocknum)
-        return f.read(self.blocksize)
+        return f.read(self.blocksize) # this might be short for the last block
 
     def remote_get_block_hashes(self):
         return str2l(self._read_file('blockhashes'))
@@ -101,8 +103,9 @@ class StorageServer(service.MultiService, Referenceable):
         alreadygot = set()
         bucketwriters = {} # k: shnum, v: BucketWriter
         for shnum in sharenums:
-            incominghome = os.path.join(self.incomingdir, idlib.a2b(verifierid), "%d"%shnum)
-            finalhome = os.path.join(self.storedir, idlib.a2b(verifierid), "%d"%shnum)
+            incominghome = os.path.join(self.incomingdir,
+                                        idlib.b2a(verifierid) +  "%d"%shnum)
+            finalhome = os.path.join(self.storedir, idlib.b2a(verifierid), "%d"%shnum)
             if os.path.exists(incominghome) or os.path.exists(finalhome):
                 alreadygot.add(shnum)
             else:
index d3311ac97a556925e059dec47b6631a650284dc4..e03837592e801e83e1cf403de0b6b5c622f457a9 100644 (file)
@@ -2,6 +2,7 @@
 
 from twisted.trial import unittest
 from twisted.internet import defer
+from foolscap import eventual
 from allmydata import encode_new, download
 from allmydata.uri import pack_uri
 from cStringIO import StringIO
@@ -14,15 +15,39 @@ class MyEncoder(encode_new.Encoder):
                 print " ", [i for i,h in args[0]]
         return defer.succeed(None)
 
-class Encode(unittest.TestCase):
-    def test_1(self):
-        e = MyEncoder()
-        data = StringIO("some data to encode\n")
-        e.setup(data)
-        d = e.start()
+class FakePeer:
+    def __init__(self, mode="good"):
+        self.ss = FakeStorageServer(mode)
+
+    def callRemote(self, methname, *args, **kwargs):
+        def _call():
+            meth = getattr(self, methname)
+            return meth(*args, **kwargs)
+        return defer.maybeDeferred(_call)
+
+    def get_service(self, sname):
+        assert sname == "storageserver"
+        return self.ss
+
+class FakeStorageServer:
+    def __init__(self, mode):
+        self.mode = mode
+    def callRemote(self, methname, *args, **kwargs):
+        def _call():
+            meth = getattr(self, methname)
+            return meth(*args, **kwargs)
+        d = eventual.fireEventually()
+        d.addCallback(lambda res: _call())
         return d
+    def allocate_buckets(self, verifierid, sharenums, shareize, blocksize, canary):
+        if self.mode == "full":
+            return (set(), {},)
+        elif self.mode == "already got them":
+            return (set(sharenums), {},)
+        else:
+            return (set(), dict([(shnum, FakeBucketWriter(),) for shnum in sharenums]),)
 
-class FakePeer:
+class FakeBucketWriter:
     def __init__(self):
         self.blocks = {}
         self.block_hashes = None
@@ -65,7 +90,7 @@ class FakePeer:
         return self.share_hashes
 
 
-class UpDown(unittest.TestCase):
+class Encode(unittest.TestCase):
     def test_send(self):
         e = encode_new.Encoder()
         data = "happy happy joy joy" * 4
@@ -79,7 +104,7 @@ class UpDown(unittest.TestCase):
         shareholders = {}
         all_shareholders = []
         for shnum in range(NUM_SHARES):
-            peer = FakePeer()
+            peer = FakeBucketWriter()
             shareholders[shnum] = peer
             all_shareholders.append(peer)
         e.set_shareholders(shareholders)
@@ -104,20 +129,24 @@ class UpDown(unittest.TestCase):
 
         return d
 
-    def test_send_and_recover(self):
+class Roundtrip(unittest.TestCase):
+    def send_and_recover(self, NUM_SHARES, NUM_PEERS, NUM_SEGMENTS=4):
         e = encode_new.Encoder()
         data = "happy happy joy joy" * 4
         e.setup(StringIO(data))
-        NUM_SHARES = 100
+
         assert e.num_shares == NUM_SHARES # else we'll be completely confused
         e.segment_size = 25 # force use of multiple segments
         e.setup_codec() # need to rebuild the codec for that change
-        NUM_SEGMENTS = 4
+
         assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
         shareholders = {}
         all_shareholders = []
+        all_peers = []
+        for i in range(NUM_PEERS):
+            all_peers.append(FakeBucketWriter())
         for shnum in range(NUM_SHARES):
-            peer = FakePeer()
+            peer = all_peers[shnum % NUM_PEERS]
             shareholders[shnum] = peer
             all_shareholders.append(peer)
         e.set_shareholders(shareholders)
@@ -125,6 +154,7 @@ class UpDown(unittest.TestCase):
         def _uploaded(roothash):
             URI = pack_uri(e._codec.get_encoder_type(),
                            e._codec.get_serialized_params(),
+                           e._tail_codec.get_serialized_params(),
                            "V" * 20,
                            roothash,
                            e.required_shares,
@@ -138,7 +168,7 @@ class UpDown(unittest.TestCase):
             for shnum in range(NUM_SHARES):
                 fd._share_buckets[shnum] = set([all_shareholders[shnum]])
             fd._got_all_shareholders(None)
-            d2 = fd._download_all_segments()
+            d2 = fd._download_all_segments(None)
             d2.addCallback(fd._done)
             return d2
         d.addCallback(_uploaded)
@@ -147,3 +177,7 @@ class UpDown(unittest.TestCase):
         d.addCallback(_downloaded)
 
         return d
+
+    def test_one_share_per_peer(self):
+        return self.send_and_recover(100, 100)
+
diff --git a/src/allmydata/test/test_storage.py b/src/allmydata/test/test_storage.py
deleted file mode 100644 (file)
index 31da06a..0000000
+++ /dev/null
@@ -1,156 +0,0 @@
-
-import os
-import random
-
-from twisted.trial import unittest
-from twisted.application import service
-from twisted.internet import defer
-from foolscap import Tub, Referenceable
-from foolscap.eventual import flushEventualQueue
-
-from allmydata import client
-
-class Canary(Referenceable):
-    pass
-
-class StorageTest(unittest.TestCase):
-
-    def setUp(self):
-        self.svc = service.MultiService()
-        self.node = client.Client('')
-        self.node.setServiceParent(self.svc)
-        self.tub = Tub()
-        self.tub.setServiceParent(self.svc)
-        self.svc.startService()
-        return self.node.when_tub_ready()
-
-    def test_create_bucket(self):
-        """
-        Check that the storage server can return bucket data accurately.
-        """
-        vid = os.urandom(20)
-        bnum = random.randrange(0, 256)
-        data = os.urandom(random.randint(1024, 16384))
-
-        rssd = self.tub.getReference(self.node.my_pburl)
-        def get_storageserver(node):
-            return node.callRemote('get_service', name='storageserver')
-        rssd.addCallback(get_storageserver)
-
-        def create_bucket(storageserver):
-            return storageserver.callRemote('allocate_bucket',
-                                            verifierid=vid,
-                                            bucket_num=bnum,
-                                            size=len(data),
-                                            leaser=self.node.nodeid,
-                                            canary=Canary(),
-                                            )
-        rssd.addCallback(create_bucket)
-
-        def write_to_bucket(bucket):
-            def write_some(junk, bytes):
-                return bucket.callRemote('write', data=bytes)
-            def set_metadata(junk, metadata):
-                return bucket.callRemote('set_metadata', metadata)
-            def finalise(junk):
-                return bucket.callRemote('close')
-            off1 = len(data) / 2
-            off2 = 3 * len(data) / 4
-            d = defer.succeed(None)
-            d.addCallback(write_some, data[:off1])
-            d.addCallback(write_some, data[off1:off2])
-            d.addCallback(set_metadata, "metadata")
-            d.addCallback(write_some, data[off2:])
-            d.addCallback(finalise)
-            return d
-        rssd.addCallback(write_to_bucket)
-
-        def get_node_again(junk):
-            return self.tub.getReference(self.node.my_pburl)
-        rssd.addCallback(get_node_again)
-        rssd.addCallback(get_storageserver)
-
-        def get_buckets(storageserver):
-            return storageserver.callRemote('get_buckets', verifierid=vid)
-        rssd.addCallback(get_buckets)
-
-        def read_buckets(buckets):
-            self.failUnlessEqual(len(buckets), 1)
-            bucket_num, bucket = buckets[0]
-            self.failUnlessEqual(bucket_num, bnum)
-
-            def check_data(bytes_read):
-                self.failUnlessEqual(bytes_read, data)
-            d = bucket.callRemote('read')
-            d.addCallback(check_data)
-
-            def check_metadata(metadata):
-                self.failUnlessEqual(metadata, 'metadata')
-            d.addCallback(lambda res: bucket.callRemote('get_metadata'))
-            d.addCallback(check_metadata)
-            return d
-        rssd.addCallback(read_buckets)
-
-        return rssd
-
-    def test_overwrite(self):
-        """
-        Check that the storage server rejects an attempt to write too much data.
-        """
-        vid = os.urandom(20)
-        bnum = random.randrange(0, 256)
-        data = os.urandom(random.randint(1024, 16384))
-
-        rssd = self.tub.getReference(self.node.my_pburl)
-        def get_storageserver(node):
-            return node.callRemote('get_service', name='storageserver')
-        rssd.addCallback(get_storageserver)
-
-        def create_bucket(storageserver):
-            return storageserver.callRemote('allocate_bucket',
-                                            verifierid=vid,
-                                            bucket_num=bnum,
-                                            size=len(data),
-                                            leaser=self.node.nodeid,
-                                            canary=Canary(),
-                                            )
-        rssd.addCallback(create_bucket)
-
-        def write_to_bucket(bucket):
-            def write_some(junk, bytes):
-                return bucket.callRemote('write', data=bytes)
-            def finalise(junk):
-                return bucket.callRemote('close')
-            off1 = len(data) / 2
-            off2 = 3 * len(data) / 4
-            d = defer.succeed(None)
-            d.addCallback(write_some, data[:off1])
-            d.addCallback(write_some, data[off1:off2])
-            d.addCallback(write_some, data[off2:])
-            # and then overwrite
-            d.addCallback(write_some, data[off1:off2])
-            d.addCallback(finalise)
-            return d
-        rssd.addCallback(write_to_bucket)
-
-        self.deferredShouldFail(rssd, ftype=AssertionError)
-        return rssd
-
-    def deferredShouldFail(self, d, ftype=None, checker=None):
-
-        def _worked(res):
-            self.fail("hey, this was supposed to fail, not return %s" % res)
-        if not ftype and not checker:
-            d.addCallbacks(_worked,
-                           lambda f: None)
-        elif ftype and not checker:
-            d.addCallbacks(_worked,
-                           lambda f: f.trap(ftype) or None)
-        else:
-            d.addCallbacks(_worked,
-                           checker)
-
-    def tearDown(self):
-        d = self.svc.stopService()
-        d.addCallback(lambda res: flushEventualQueue())
-        return d
index fbf6327a9b22ae4e824a56495796b56cd20d3162..5eb9b540a77e760c89cf4193c54d454da7b5cad3 100644 (file)
@@ -1,40 +1,18 @@
 
 from twisted.trial import unittest
-from twisted.python import log
 from twisted.python.failure import Failure
-from twisted.internet import defer
 from cStringIO import StringIO
 
-from foolscap import eventual
-
 from allmydata import upload
 from allmydata.uri import unpack_uri
 
 from test_encode import FakePeer
 
-class FakeStorageServer:
-    def __init__(self, mode):
-        self.mode = mode
-    def callRemote(self, methname, *args, **kwargs):
-        def _call():
-            meth = getattr(self, methname)
-            return meth(*args, **kwargs)
-        d = eventual.fireEventually()
-        d.addCallback(lambda res: _call())
-        return d
-    def allocate_buckets(self, verifierid, sharenums, shareize, blocksize, canary):
-        if self.mode == "full":
-            return (set(), {},)
-        elif self.mode == "already got them":
-            return (set(sharenums), {},)
-        else:
-            return (set(), dict([(shnum, FakePeer(),) for shnum in sharenums]),)
-
 class FakeClient:
     def __init__(self, mode="good"):
         self.mode = mode
     def get_permuted_peers(self, verifierid):
-        return [ ("%20d"%fakeid, "%20d"%fakeid, FakeStorageServer(self.mode),) for fakeid in range(50) ]
+        return [ ("%20d"%fakeid, "%20d"%fakeid, FakePeer(self.mode),) for fakeid in range(50) ]
 
 class GoodServer(unittest.TestCase):
     def setUp(self):
@@ -46,13 +24,11 @@ class GoodServer(unittest.TestCase):
     def _check(self, uri):
         self.failUnless(isinstance(uri, str))
         self.failUnless(uri.startswith("URI:"))
-        codec_name, codec_params, verifierid = unpack_uri(uri)
+        codec_name, codec_params, tail_codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size = unpack_uri(uri)
         self.failUnless(isinstance(verifierid, str))
         self.failUnlessEqual(len(verifierid), 20)
         self.failUnless(isinstance(codec_params, str))
-        peers = self.node.peers
-        self.failUnlessEqual(peers[0].allocated_size,
-                             len(peers[0].data))
+
     def testData(self):
         data = "This is some data to upload"
         d = self.u.upload_data(data)
index 7559319d5d47aba1bb6961948113ddd6850fa951..7d9a28b0b1370c3172995a8dce989a1884049d83 100644 (file)
@@ -5,7 +5,6 @@ from twisted.application import service
 from foolscap import Referenceable
 
 from allmydata.util import idlib
-from allmydata.util.assertutil import _assert
 from allmydata import encode_new
 from allmydata.uri import pack_uri
 from allmydata.interfaces import IUploadable, IUploader
@@ -28,16 +27,26 @@ class PeerTracker:
     def __init__(self, peerid, permutedid, connection, sharesize, blocksize, verifierid):
         self.peerid = peerid
         self.permutedid = permutedid
-        self.connection = connection
+        self.connection = connection # to an RIClient
         self.buckets = {} # k: shareid, v: IRemoteBucketWriter
         self.sharesize = sharesize
         self.blocksize = blocksize
         self.verifierid = verifierid
+        self._storageserver = None
 
     def query(self, sharenums):
-        d = self.connection.callRemote("allocate_buckets", self.verifierid,
-                                       sharenums, self.sharesize,
-                                       self.blocksize, canary=Referenceable())
+        if not self._storageserver:
+            d = self.connection.callRemote("get_service", "storageserver")
+            d.addCallback(self._got_storageserver)
+            d.addCallback(lambda res: self._query(sharenums))
+            return d
+        return self._query(sharenums)
+    def _got_storageserver(self, storageserver):
+        self._storageserver = storageserver
+    def _query(self, sharenums):
+        d = self._storageserver.callRemote("allocate_buckets", self.verifierid,
+                                           sharenums, self.sharesize,
+                                           self.blocksize, canary=Referenceable())
         d.addCallback(self._got_reply)
         return d
         
@@ -194,7 +203,7 @@ class FileUploader:
         self.unallocated_sharenums -= allocated
 
         if allocated:
-            self.usable_peers.add(peer)
+            self.used_peers.add(peer)
 
         if shares_we_requested - alreadygot - allocated:
             log.msg("%s._got_response(%s, %s, %s): self.unallocated_sharenums: %s, unhandled: %s HE'S FULL" % (self, (alreadygot, allocated), peer, shares_we_requested, self.unallocated_sharenums, shares_we_requested - alreadygot - allocated))
@@ -222,7 +231,11 @@ class FileUploader:
     def _compute_uri(self, roothash):
         codec_type = self._encoder._codec.get_encoder_type()
         codec_params = self._encoder._codec.get_serialized_params()
-        return pack_uri(codec_type, codec_params, self._verifierid, roothash, self.needed_shares, self.total_shares, self._size, self._encoder.segment_size)
+        tail_codec_params = self._encoder._tail_codec.get_serialized_params()
+        return pack_uri(codec_type, codec_params, tail_codec_params,
+                        self._verifierid,
+                        roothash, self.needed_shares, self.total_shares,
+                        self._size, self._encoder.segment_size)
 
 
 def netstring(s):
index ed43eb7eca734021c3fd64b6aa31e9a6a6cb2aa6..a0f77fdd0d423441ff2750bd06d0452512f9786d 100644 (file)
@@ -5,26 +5,28 @@ from allmydata.util import idlib
 # enough information to retrieve and validate the contents. It shall be
 # expressed in a limited character set (namely [TODO]).
 
-def pack_uri(codec_name, codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size):
+def pack_uri(codec_name, codec_params, tail_codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size):
     assert isinstance(codec_name, str)
     assert len(codec_name) < 10
     assert ":" not in codec_name
     assert isinstance(codec_params, str)
     assert ":" not in codec_params
+    assert isinstance(tail_codec_params, str)
+    assert ":" not in tail_codec_params
     assert isinstance(verifierid, str)
     assert len(verifierid) == 20 # sha1 hash
-    return "URI:%s:%s:%s:%s:%s:%s:%s:%s" % (codec_name, codec_params, idlib.b2a(verifierid), idlib.b2a(roothash), needed_shares, total_shares, size, segment_size)
+    return "URI:%s:%s:%s:%s:%s:%s:%s:%s:%s" % (codec_name, codec_params, tail_codec_params, idlib.b2a(verifierid), idlib.b2a(roothash), needed_shares, total_shares, size, segment_size)
 
 
 def unpack_uri(uri):
     assert uri.startswith("URI:")
-    header, codec_name, codec_params, verifierid_s, roothash_s, needed_shares_s, total_shares_s, size_s, segment_size_s = uri.split(":")
+    header, codec_name, codec_params, tail_codec_params, verifierid_s, roothash_s, needed_shares_s, total_shares_s, size_s, segment_size_s = uri.split(":")
     verifierid = idlib.a2b(verifierid_s)
     roothash = idlib.a2b(roothash_s)
     needed_shares = int(needed_shares_s)
     total_shares = int(total_shares_s)
     size = int(size_s)
     segment_size = int(segment_size_s)
-    return codec_name, codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size
+    return codec_name, codec_params, tail_codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size