From: Brian Warner Date: Fri, 30 Mar 2007 20:20:01 +0000 (-0700) Subject: add new test for doing an encode/decode round trip, and make it almost work X-Git-Url: https://git.rkrishnan.org/somewhere?a=commitdiff_plain;h=234b2f354e59503edb05b5df3fc06c331fcafc03;p=tahoe-lafs%2Ftahoe-lafs.git add new test for doing an encode/decode round trip, and make it almost work --- diff --git a/src/allmydata/download.py b/src/allmydata/download.py index d7ed7d9b..bafce789 100644 --- a/src/allmydata/download.py +++ b/src/allmydata/download.py @@ -7,6 +7,7 @@ from twisted.application import service from allmydata.util import idlib, bencode, mathutil from allmydata.util.deferredutil import DeferredListShouldSucceed +from allmydata.util.assertutil import _assert from allmydata import codec from allmydata.Crypto.Cipher import AES from allmydata.uri import unpack_uri @@ -27,13 +28,24 @@ class Output: counterstart="\x00"*16) self._verifierid_hasher = sha.new(netstring("allmydata_v1_verifierid")) self._fileid_hasher = sha.new(netstring("allmydata_v1_fileid")) + self.length = 0 + + def open(self): + self.downloadable.open() + def write(self, crypttext): + self.length += len(crypttext) self._verifierid_hasher.update(crypttext) plaintext = self._decryptor.decrypt(crypttext) self._fileid_hasher.update(plaintext) self.downloadable.write(plaintext) - def finish(self): + + def close(self): + self.verifierid = self._verifierid_hasher.digest() + self.fileid = self._fileid_hasher.digest() self.downloadable.close() + + def finish(self): return self.downloadable.finish() class BlockDownloader: @@ -51,10 +63,12 @@ class BlockDownloader: self.parent.hold_block(self.blocknum, data) def _got_block_error(self, f): + log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f)) self.parent.bucket_failed(self.blocknum, self.bucket) class SegmentDownloader: - def __init__(self, segmentnumber, needed_shares): + def __init__(self, parent, segmentnumber, needed_shares): + self.parent = parent self.segmentnumber = segmentnumber self.needed_blocks = needed_shares self.blocks = {} # k: blocknum, v: data @@ -66,7 +80,14 @@ class SegmentDownloader: d = self._try() def _done(res): if len(self.blocks) >= self.needed_blocks: - return self.blocks + # we only need self.needed_blocks blocks + # we want to get the smallest blockids, because they are + # more likely to be fast "primary blocks" + blockids = sorted(self.blocks.keys())[:self.needed_blocks] + blocks = [] + for blocknum in blockids: + blocks.append(self.blocks[blocknum]) + return (blocks, blockids) else: return self._download() d.addCallback(_done) @@ -79,14 +100,19 @@ class SegmentDownloader: if not otherblocknums: raise NotEnoughPeersError blocknum = random.choice(otherblocknums) - self.parent.active_buckets[blocknum] = random.choice(self.parent._share_buckets[blocknum]) + bucket = random.choice(list(self.parent._share_buckets[blocknum])) + self.parent.active_buckets[blocknum] = bucket # Now we have enough buckets, in self.parent.active_buckets. - l = [] + + # in test cases, bd.start might mutate active_buckets right away, so + # we need to put off calling start() until we've iterated all the way + # through it + downloaders = [] for blocknum, bucket in self.parent.active_buckets.iteritems(): bd = BlockDownloader(bucket, blocknum, self) - d = bd.start(self.segmentnumber) - l.append(d) + downloaders.append(bd) + l = [bd.start(self.segmentnumber) for bd in downloaders] return defer.DeferredList(l) def hold_block(self, blocknum, data): @@ -115,7 +141,11 @@ class FileDownloader: self._total_segments = mathutil.div_ceil(size, segment_size) self._current_segnum = 0 self._segment_size = segment_size - self._needed_shares = self._decoder.get_needed_shares() + self._size = size + self._num_needed_shares = self._decoder.get_needed_shares() + + key = "\x00" * 16 + self._output = Output(downloadable, key) # future: # self._share_hash_tree = ?? @@ -134,9 +164,6 @@ class FileDownloader: self.active_buckets = {} # k: shnum, v: bucket self._share_buckets = {} # k: shnum, v: set of buckets - key = "\x00" * 16 - self._output = Output(self._downloadable, key) - d = defer.maybeDeferred(self._get_all_shareholders) d.addCallback(self._got_all_shareholders) d.addCallback(self._download_all_segments) @@ -160,11 +187,12 @@ class FileDownloader: self._client.log("Somebody failed. -- %s" % (f,)) def _got_all_shareholders(self, res): - if len(self._share_buckets) < self._needed_shares: + if len(self._share_buckets) < self._num_needed_shares: raise NotEnoughPeersError self.active_buckets = {} - + self._output.open() + def _download_all_segments(self): d = self._download_segment(self._current_segnum) def _done(res): @@ -175,74 +203,33 @@ class FileDownloader: return d def _download_segment(self, segnum): - segmentdler = SegmentDownloader(segnum, self._needed_shares) + segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares) d = segmentdler.start() - d.addCallback(self._decoder.decode) + d.addCallback(lambda (shares, shareids): + self._decoder.decode(shares, shareids)) def _done(res): self._current_segnum += 1 if self._current_segnum == self._total_segments: data = ''.join(res) padsize = mathutil.pad_size(self._size, self._segment_size) data = data[:-padsize] - self.output.write(data) + self._output.write(data) else: for buf in res: - self.output.write(buf) + self._output.write(buf) d.addCallback(_done) return d def _done(self, res): + self._output.close() + #print "VERIFIERID: %s" % idlib.b2a(self._output.verifierid) + #print "FILEID: %s" % idlib.b2a(self._output.fileid) + #assert self._verifierid == self._output.verifierid + #assert self._fileid = self._output.fileid + _assert(self._output.length == self._size, + got=self._output.length, expected=self._size) return self._output.finish() - - def _write_data(self, data): - self._verifierid_hasher.update(data) - - - -# old stuff - def _got_all_peers(self, res): - all_buckets = [] - for peerid, buckets in self.landlords: - all_buckets.extend(buckets) - # TODO: try to avoid pulling multiple shares from the same peer - all_buckets = all_buckets[:self.needed_shares] - # retrieve all shares - dl = [] - shares = [] - shareids = [] - for (bucket_num, bucket) in all_buckets: - d0 = bucket.callRemote("get_metadata") - d1 = bucket.callRemote("read") - d2 = DeferredListShouldSucceed([d0, d1]) - def _got(res): - shareid_s, sharedata = res - shareid = bencode.bdecode(shareid_s) - shares.append(sharedata) - shareids.append(shareid) - d2.addCallback(_got) - dl.append(d2) - d = DeferredListShouldSucceed(dl) - - d.addCallback(lambda res: self._decoder.decode(shares, shareids)) - - def _write(decoded_shares): - data = "".join(decoded_shares) - self._target.open() - hasher = sha.new(netstring("allmydata_v1_verifierid")) - hasher.update(data) - vid = hasher.digest() - assert self._verifierid == vid, "%s != %s" % (idlib.b2a(self._verifierid), idlib.b2a(vid)) - self._target.write(data) - d.addCallback(_write) - def _done(res): - self._target.close() - return self._target.finish() - def _fail(res): - self._target.fail() - return res - d.addCallbacks(_done, _fail) - return d def netstring(s): return "%d:%s," % (len(s), s) diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index 0eb54222..d3311ac9 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -2,7 +2,8 @@ from twisted.trial import unittest from twisted.internet import defer -from allmydata import encode_new +from allmydata import encode_new, download +from allmydata.uri import pack_uri from cStringIO import StringIO class MyEncoder(encode_new.Encoder): @@ -24,8 +25,8 @@ class Encode(unittest.TestCase): class FakePeer: def __init__(self): self.blocks = {} - self.blockhashes = None - self.sharehashes = None + self.block_hashes = None + self.share_hashes = None self.closed = False def callRemote(self, methname, *args, **kwargs): @@ -41,19 +42,29 @@ class FakePeer: def put_block_hashes(self, blockhashes): assert not self.closed - assert self.blockhashes is None - self.blockhashes = blockhashes + assert self.block_hashes is None + self.block_hashes = blockhashes def put_share_hashes(self, sharehashes): assert not self.closed - assert self.sharehashes is None - self.sharehashes = sharehashes + assert self.share_hashes is None + self.share_hashes = sharehashes def close(self): assert not self.closed self.closed = True + def get_block(self, blocknum): + assert isinstance(blocknum, int) + return self.blocks[blocknum] + + def get_block_hashes(self): + return self.block_hashes + def get_share_hashes(self): + return self.share_hashes + + class UpDown(unittest.TestCase): def test_send(self): e = encode_new.Encoder() @@ -79,16 +90,60 @@ class UpDown(unittest.TestCase): for i,peer in enumerate(all_shareholders): self.failUnless(peer.closed) self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS) - #self.failUnlessEqual(len(peer.blockhashes), NUM_SEGMENTS) + #self.failUnlessEqual(len(peer.block_hashes), NUM_SEGMENTS) # that isn't true: each peer gets a full tree, so it's more # like 2n-1 but with rounding to a power of two - for h in peer.blockhashes: + for h in peer.block_hashes: self.failUnlessEqual(len(h), 32) - #self.failUnlessEqual(len(peer.sharehashes), NUM_SHARES) + #self.failUnlessEqual(len(peer.share_hashes), NUM_SHARES) # that isn't true: each peer only gets the chain they need - for (hashnum, h) in peer.sharehashes: + for (hashnum, h) in peer.share_hashes: self.failUnless(isinstance(hashnum, int)) self.failUnlessEqual(len(h), 32) d.addCallback(_check) return d + + def test_send_and_recover(self): + e = encode_new.Encoder() + data = "happy happy joy joy" * 4 + e.setup(StringIO(data)) + NUM_SHARES = 100 + assert e.num_shares == NUM_SHARES # else we'll be completely confused + e.segment_size = 25 # force use of multiple segments + e.setup_codec() # need to rebuild the codec for that change + NUM_SEGMENTS = 4 + assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size + shareholders = {} + all_shareholders = [] + for shnum in range(NUM_SHARES): + peer = FakePeer() + shareholders[shnum] = peer + all_shareholders.append(peer) + e.set_shareholders(shareholders) + d = e.start() + def _uploaded(roothash): + URI = pack_uri(e._codec.get_encoder_type(), + e._codec.get_serialized_params(), + "V" * 20, + roothash, + e.required_shares, + e.num_shares, + e.file_size, + e.segment_size) + client = None + target = download.Data() + fd = download.FileDownloader(client, URI, target) + fd._share_buckets = {} + for shnum in range(NUM_SHARES): + fd._share_buckets[shnum] = set([all_shareholders[shnum]]) + fd._got_all_shareholders(None) + d2 = fd._download_all_segments() + d2.addCallback(fd._done) + return d2 + d.addCallback(_uploaded) + def _downloaded(newdata): + self.failUnless(newdata == data) + d.addCallback(_downloaded) + + return d