From 234b2f354e59503edb05b5df3fc06c331fcafc03 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Fri, 30 Mar 2007 13:20:01 -0700
Subject: [PATCH] add a new test that does an encode/decode round trip, and
 make it almost work

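The Output helper now tracks the number of bytes written and exposes
open() and close() in addition to write() and finish().
SegmentDownloader takes a reference to its parent and returns
(blocks, blockids) so the decoder knows which shares it was given.
FileDownloader builds its Output up front and checks the recovered
length against the expected file size. The new test_send_and_recover
encodes 100 shares over 4 segments and feeds them straight back into
FileDownloader.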
---
 src/allmydata/download.py         | 126 ++++++++++++++----------------
 src/allmydata/test/test_encode.py |  81 ++++++++++++++++---
 2 files changed, 129 insertions(+), 78 deletions(-)

diff --git a/src/allmydata/download.py b/src/allmydata/download.py
index d7ed7d9b..bafce789 100644
--- a/src/allmydata/download.py
+++ b/src/allmydata/download.py
@@ -7,6 +7,7 @@ from twisted.application import service
 
 from allmydata.util import idlib, bencode, mathutil
 from allmydata.util.deferredutil import DeferredListShouldSucceed
+from allmydata.util.assertutil import _assert
 from allmydata import codec
 from allmydata.Crypto.Cipher import AES
 from allmydata.uri import unpack_uri
@@ -27,13 +28,26 @@ class Output:
                                   counterstart="\x00"*16)
         self._verifierid_hasher = sha.new(netstring("allmydata_v1_verifierid"))
         self._fileid_hasher = sha.new(netstring("allmydata_v1_fileid"))
+        self.length = 0
+
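+    # pass-through lifecycle: open() before the first write, close() after
+    # the last (computing the verifierid/fileid digests), then finish()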
+    def open(self):
+        self.downloadable.open()
+
     def write(self, crypttext):
+        self.length += len(crypttext)
         self._verifierid_hasher.update(crypttext)
         plaintext = self._decryptor.decrypt(crypttext)
         self._fileid_hasher.update(plaintext)
         self.downloadable.write(plaintext)
-    def finish(self):
+
+    def close(self):
+        self.verifierid = self._verifierid_hasher.digest()
+        self.fileid = self._fileid_hasher.digest()
         self.downloadable.close()
+
+    def finish(self):
         return self.downloadable.finish()
 
 class BlockDownloader:
@@ -51,10 +65,12 @@ class BlockDownloader:
         self.parent.hold_block(self.blocknum, data)
 
     def _got_block_error(self, f):
+        log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
         self.parent.bucket_failed(self.blocknum, self.bucket)
 
 class SegmentDownloader:
-    def __init__(self, segmentnumber, needed_shares):
+    def __init__(self, parent, segmentnumber, needed_shares):
+        self.parent = parent
         self.segmentnumber = segmentnumber
         self.needed_blocks = needed_shares
         self.blocks = {} # k: blocknum, v: data
@@ -66,7 +82,14 @@ class SegmentDownloader:
         d = self._try()
         def _done(res):
             if len(self.blocks) >= self.needed_blocks:
-                return self.blocks
+                # we only need self.needed_blocks blocks
+                # we want to get the smallest blockids, because they are
+                # more likely to be fast "primary blocks"
+                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
+                blocks = []
+                for blocknum in blockids:
+                    blocks.append(self.blocks[blocknum])
+                return (blocks, blockids)
             else:
                 return self._download()
         d.addCallback(_done)
@@ -79,14 +102,19 @@ class SegmentDownloader:
             if not otherblocknums:
                 raise NotEnoughPeersError
             blocknum = random.choice(otherblocknums)
-            self.parent.active_buckets[blocknum] = random.choice(self.parent._share_buckets[blocknum])
+            bucket = random.choice(list(self.parent._share_buckets[blocknum]))
+            self.parent.active_buckets[blocknum] = bucket
 
         # Now we have enough buckets, in self.parent.active_buckets.
-        l = []
+
+        # in test cases, bd.start might mutate active_buckets right away, so
+        # we need to put off calling start() until we've iterated all the way
+        # through it
+        downloaders = []
         for blocknum, bucket in self.parent.active_buckets.iteritems():
             bd = BlockDownloader(bucket, blocknum, self)
-            d = bd.start(self.segmentnumber)
-            l.append(d)
+            downloaders.append(bd)
+        l = [bd.start(self.segmentnumber) for bd in downloaders]
         return defer.DeferredList(l)
 
     def hold_block(self, blocknum, data):
@@ -115,7 +143,12 @@ class FileDownloader:
         self._total_segments = mathutil.div_ceil(size, segment_size)
         self._current_segnum = 0
         self._segment_size = segment_size
-        self._needed_shares = self._decoder.get_needed_shares()
+        self._size = size
+        self._num_needed_shares = self._decoder.get_needed_shares()
+
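+        # XXX hardcoded key for now; must match the key used when encoding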
+        key = "\x00" * 16
+        self._output = Output(downloadable, key)
 
         # future:
         # self._share_hash_tree = ??
@@ -134,9 +167,6 @@ class FileDownloader:
         self.active_buckets = {} # k: shnum, v: bucket
         self._share_buckets = {} # k: shnum, v: set of buckets
 
-        key = "\x00" * 16
-        self._output = Output(self._downloadable, key)
- 
         d = defer.maybeDeferred(self._get_all_shareholders)
         d.addCallback(self._got_all_shareholders)
         d.addCallback(self._download_all_segments)
@@ -160,11 +190,12 @@ class FileDownloader:
         self._client.log("Somebody failed. -- %s" % (f,))
 
     def _got_all_shareholders(self, res):
-        if len(self._share_buckets) < self._needed_shares:
+        if len(self._share_buckets) < self._num_needed_shares:
             raise NotEnoughPeersError
 
         self.active_buckets = {}
-        
+        self._output.open()
+
     def _download_all_segments(self):
         d = self._download_segment(self._current_segnum)
         def _done(res):
@@ -175,74 +206,35 @@
         return d
 
     def _download_segment(self, segnum):
-        segmentdler = SegmentDownloader(segnum, self._needed_shares)
+        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
-        d.addCallback(self._decoder.decode)
+        d.addCallback(lambda (shares, shareids):
+                      self._decoder.decode(shares, shareids))
         def _done(res):
             self._current_segnum += 1
             if self._current_segnum == self._total_segments:
                 data = ''.join(res)
                 padsize = mathutil.pad_size(self._size, self._segment_size)
-                data = data[:-padsize]
+                # data[:-0] would discard everything, so skip when padsize == 0
+                if padsize:
+                    data = data[:-padsize]
-                self.output.write(data)
+                self._output.write(data)
             else:
                 for buf in res:
-                    self.output.write(buf)
+                    self._output.write(buf)
         d.addCallback(_done)
         return d
 
     def _done(self, res):
+        self._output.close()
+        #print "VERIFIERID: %s" % idlib.b2a(self._output.verifierid)
+        #print "FILEID: %s" % idlib.b2a(self._output.fileid)
+        #assert self._verifierid == self._output.verifierid
+        #assert self._fileid == self._output.fileid
+        _assert(self._output.length == self._size,
+                got=self._output.length, expected=self._size)
         return self._output.finish()
-        
-    def _write_data(self, data):
-        self._verifierid_hasher.update(data)
-        
-        
-
-# old stuff
-    def _got_all_peers(self, res):
-        all_buckets = []
-        for peerid, buckets in self.landlords:
-            all_buckets.extend(buckets)
-        # TODO: try to avoid pulling multiple shares from the same peer
-        all_buckets = all_buckets[:self.needed_shares]
-        # retrieve all shares
-        dl = []
-        shares = []
-        shareids = []
-        for (bucket_num, bucket) in all_buckets:
-            d0 = bucket.callRemote("get_metadata")
-            d1 = bucket.callRemote("read")
-            d2 = DeferredListShouldSucceed([d0, d1])
-            def _got(res):
-                shareid_s, sharedata = res
-                shareid = bencode.bdecode(shareid_s)
-                shares.append(sharedata)
-                shareids.append(shareid)
-            d2.addCallback(_got)
-            dl.append(d2)
-        d = DeferredListShouldSucceed(dl)
-
-        d.addCallback(lambda res: self._decoder.decode(shares, shareids))
-
-        def _write(decoded_shares):
-            data = "".join(decoded_shares)
-            self._target.open()
-            hasher = sha.new(netstring("allmydata_v1_verifierid"))
-            hasher.update(data)
-            vid = hasher.digest()
-            assert self._verifierid == vid, "%s != %s" % (idlib.b2a(self._verifierid), idlib.b2a(vid))
-            self._target.write(data)
-        d.addCallback(_write)
 
-        def _done(res):
-            self._target.close()
-            return self._target.finish()
-        def _fail(res):
-            self._target.fail()
-            return res
-        d.addCallbacks(_done, _fail)
-        return d
 
 def netstring(s):
     return "%d:%s," % (len(s), s)
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index 0eb54222..d3311ac9 100644
--- a/src/allmydata/test/test_encode.py
+++ b/src/allmydata/test/test_encode.py
@@ -2,7 +2,8 @@
 
 from twisted.trial import unittest
 from twisted.internet import defer
-from allmydata import encode_new
+from allmydata import encode_new, download
+from allmydata.uri import pack_uri
 from cStringIO import StringIO
 
 class MyEncoder(encode_new.Encoder):
@@ -24,8 +25,8 @@ class Encode(unittest.TestCase):
 class FakePeer:
     def __init__(self):
         self.blocks = {}
-        self.blockhashes = None
-        self.sharehashes = None
+        self.block_hashes = None
+        self.share_hashes = None
         self.closed = False
 
     def callRemote(self, methname, *args, **kwargs):
@@ -41,19 +42,29 @@ class FakePeer:
     
     def put_block_hashes(self, blockhashes):
         assert not self.closed
-        assert self.blockhashes is None
-        self.blockhashes = blockhashes
+        assert self.block_hashes is None
+        self.block_hashes = blockhashes
         
     def put_share_hashes(self, sharehashes):
         assert not self.closed
-        assert self.sharehashes is None
-        self.sharehashes = sharehashes
+        assert self.share_hashes is None
+        self.share_hashes = sharehashes
 
     def close(self):
         assert not self.closed
         self.closed = True
 
 
+    def get_block(self, blocknum):
+        assert isinstance(blocknum, int)
+        return self.blocks[blocknum]
+
+    def get_block_hashes(self):
+        return self.block_hashes
+    def get_share_hashes(self):
+        return self.share_hashes
+
+
 class UpDown(unittest.TestCase):
     def test_send(self):
         e = encode_new.Encoder()
@@ -79,16 +90,64 @@ class UpDown(unittest.TestCase):
             for i,peer in enumerate(all_shareholders):
                 self.failUnless(peer.closed)
                 self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
-                #self.failUnlessEqual(len(peer.blockhashes), NUM_SEGMENTS)
+                #self.failUnlessEqual(len(peer.block_hashes), NUM_SEGMENTS)
                 # that isn't true: each peer gets a full tree, so it's more
                 # like 2n-1 but with rounding to a power of two
-                for h in peer.blockhashes:
+                for h in peer.block_hashes:
                     self.failUnlessEqual(len(h), 32)
-                #self.failUnlessEqual(len(peer.sharehashes), NUM_SHARES)
+                #self.failUnlessEqual(len(peer.share_hashes), NUM_SHARES)
                 # that isn't true: each peer only gets the chain they need
-                for (hashnum, h) in peer.sharehashes:
+                for (hashnum, h) in peer.share_hashes:
                     self.failUnless(isinstance(hashnum, int))
                     self.failUnlessEqual(len(h), 32)
         d.addCallback(_check)
 
         return d
+
+    def test_send_and_recover(self):
+        e = encode_new.Encoder()
+        data = "happy happy joy joy" * 4
+        e.setup(StringIO(data))
+        NUM_SHARES = 100
+        assert e.num_shares == NUM_SHARES # else we'll be completely confused
+        e.segment_size = 25 # force use of multiple segments
+        e.setup_codec() # need to rebuild the codec for that change
+        NUM_SEGMENTS = 4
+        assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
+        shareholders = {}
+        all_shareholders = []
+        for shnum in range(NUM_SHARES):
+            peer = FakePeer()
+            shareholders[shnum] = peer
+            all_shareholders.append(peer)
+        e.set_shareholders(shareholders)
+        d = e.start()
+        def _uploaded(roothash):
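+            # "V"*20 is a stand-in verifierid: the downloader does not check
+            # it yet (the verifierid assertion in _done is still commented out)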
+            URI = pack_uri(e._codec.get_encoder_type(),
+                           e._codec.get_serialized_params(),
+                           "V" * 20,
+                           roothash,
+                           e.required_shares,
+                           e.num_shares,
+                           e.file_size,
+                           e.segment_size)
+            client = None
+            target = download.Data()
+            fd = download.FileDownloader(client, URI, target)
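+            # skip peer discovery: hand the fake shareholders straight to the
+            # downloader, then drive its segment loop and completion by hand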
+            fd._share_buckets = {}
+            for shnum in range(NUM_SHARES):
+                fd._share_buckets[shnum] = set([all_shareholders[shnum]])
+            fd._got_all_shareholders(None)
+            d2 = fd._download_all_segments()
+            d2.addCallback(fd._done)
+            return d2
+        d.addCallback(_uploaded)
+        def _downloaded(newdata):
+            self.failUnlessEqual(newdata, data)
+        d.addCallback(_downloaded)
+
+        return d
-- 
2.45.2