from allmydata.util import idlib, bencode, mathutil
from allmydata.util.deferredutil import DeferredListShouldSucceed
+from allmydata.util.assertutil import _assert
from allmydata import codec
from allmydata.Crypto.Cipher import AES
from allmydata.uri import unpack_uri
+from twisted.python import log  # used by BlockDownloader._got_block_error below
counterstart="\x00"*16)
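+ # verifierid covers the ciphertext (checkable without the key);
+ # fileid covers the decrypted plaintext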
self._verifierid_hasher = sha.new(netstring("allmydata_v1_verifierid"))
self._fileid_hasher = sha.new(netstring("allmydata_v1_fileid"))
+ self.length = 0
+
+ def open(self):
+ self.downloadable.open()
+
def write(self, crypttext):
+ self.length += len(crypttext)
self._verifierid_hasher.update(crypttext)
plaintext = self._decryptor.decrypt(crypttext)
self._fileid_hasher.update(plaintext)
self.downloadable.write(plaintext)
- def finish(self):
+
+ def close(self):
+ self.verifierid = self._verifierid_hasher.digest()
+ self.fileid = self._fileid_hasher.digest()
self.downloadable.close()
+
+ def finish(self):
return self.downloadable.finish()
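+
+ # Illustrative only (assumes some Downloadable implementation):
+ #   out = Output(downloadable, key="\x00"*16)
+ #   out.open()
+ #   out.write(crypttext)   # decrypts, hashes, forwards plaintext
+ #   out.close()            # freezes .verifierid and .fileid
+ #   result = out.finish()  # returns whatever the Downloadable returns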
class BlockDownloader:
self.parent.hold_block(self.blocknum, data)
def _got_block_error(self, f):
+ log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
self.parent.bucket_failed(self.blocknum, self.bucket)
class SegmentDownloader:
- def __init__(self, segmentnumber, needed_shares):
+ def __init__(self, parent, segmentnumber, needed_shares):
+ self.parent = parent
self.segmentnumber = segmentnumber
self.needed_blocks = needed_shares
self.blocks = {} # k: blocknum, v: data
d = self._try()
def _done(res):
if len(self.blocks) >= self.needed_blocks:
- return self.blocks
+ # we only need self.needed_blocks blocks
+ # we want to get the smallest blockids, because they are
+ # more likely to be fast "primary blocks"
+ blockids = sorted(self.blocks.keys())[:self.needed_blocks]
+ blocks = []
+ for blocknum in blockids:
+ blocks.append(self.blocks[blocknum])
+ return (blocks, blockids)
else:
return self._download()
d.addCallback(_done)
if not otherblocknums:
raise NotEnoughPeersError
blocknum = random.choice(otherblocknums)
- self.parent.active_buckets[blocknum] = random.choice(self.parent._share_buckets[blocknum])
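+ # _share_buckets holds sets, and random.choice needs an indexable
+ # sequence, hence the list() conversion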
+ bucket = random.choice(list(self.parent._share_buckets[blocknum]))
+ self.parent.active_buckets[blocknum] = bucket
# Now we have enough buckets, in self.parent.active_buckets.
- l = []
+
+ # in test cases, bd.start might mutate active_buckets right away, so
+ # we need to put off calling start() until we've iterated all the way
+ # through it
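+ # (in Python 2, that raises "dictionary changed size during iteration")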
+ downloaders = []
for blocknum, bucket in self.parent.active_buckets.iteritems():
bd = BlockDownloader(bucket, blocknum, self)
- d = bd.start(self.segmentnumber)
- l.append(d)
+ downloaders.append(bd)
+ l = [bd.start(self.segmentnumber) for bd in downloaders]
return defer.DeferredList(l)
def hold_block(self, blocknum, data):
self._total_segments = mathutil.div_ceil(size, segment_size)
self._current_segnum = 0
self._segment_size = segment_size
- self._needed_shares = self._decoder.get_needed_shares()
+ self._size = size
+ self._num_needed_shares = self._decoder.get_needed_shares()
+
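+ # placeholder: a fixed null key, matching the one used at encode time;
+ # real key handling is still future work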
+ key = "\x00" * 16
+ self._output = Output(downloadable, key)
# future:
# self._share_hash_tree = ??
self.active_buckets = {} # k: shnum, v: bucket
self._share_buckets = {} # k: shnum, v: set of buckets
- key = "\x00" * 16
- self._output = Output(self._downloadable, key)
-
d = defer.maybeDeferred(self._get_all_shareholders)
d.addCallback(self._got_all_shareholders)
d.addCallback(self._download_all_segments)
self._client.log("Somebody failed. -- %s" % (f,))
def _got_all_shareholders(self, res):
- if len(self._share_buckets) < self._needed_shares:
+ if len(self._share_buckets) < self._num_needed_shares:
raise NotEnoughPeersError
self.active_buckets = {}
-
+ self._output.open()
+
- def _download_all_segments(self):
+ def _download_all_segments(self, res=None): # callback result is ignored
d = self._download_segment(self._current_segnum)
def _done(res):
return d
def _download_segment(self, segnum):
- segmentdler = SegmentDownloader(segnum, self._needed_shares)
+ segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
d = segmentdler.start()
- d.addCallback(self._decoder.decode)
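+ # SegmentDownloader.start fires with (blocks, blockids); the lambda
+ # unpacks that tuple into the two arguments decode() expects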
+ d.addCallback(lambda (shares, shareids):
+ self._decoder.decode(shares, shareids))
def _done(res):
self._current_segnum += 1
if self._current_segnum == self._total_segments:
data = ''.join(res)
padsize = mathutil.pad_size(self._size, self._segment_size)
- data = data[:-padsize]
+ # guard against padsize == 0, since data[:-0] would discard the
+ # whole final segment
+ if padsize:
+ data = data[:-padsize]
- self.output.write(data)
+ self._output.write(data)
else:
for buf in res:
- self.output.write(buf)
+ self._output.write(buf)
d.addCallback(_done)
return d
def _done(self, res):
+ self._output.close()
+ #print "VERIFIERID: %s" % idlib.b2a(self._output.verifierid)
+ #print "FILEID: %s" % idlib.b2a(self._output.fileid)
+ #assert self._verifierid == self._output.verifierid
+ #assert self._fileid == self._output.fileid
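+ # _assert() includes its keyword args in the AssertionError text, so a
+ # mismatch reports both the actual and expected lengths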
+ _assert(self._output.length == self._size,
+ got=self._output.length, expected=self._size)
return self._output.finish()
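+ # finish() returns whatever the Downloadable produced (the assembled
+ # plaintext, for a target like download.Data), which becomes the
+ # deferred's result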
-
- def _write_data(self, data):
- self._verifierid_hasher.update(data)
-
-
-
-# old stuff
- def _got_all_peers(self, res):
- all_buckets = []
- for peerid, buckets in self.landlords:
- all_buckets.extend(buckets)
- # TODO: try to avoid pulling multiple shares from the same peer
- all_buckets = all_buckets[:self.needed_shares]
- # retrieve all shares
- dl = []
- shares = []
- shareids = []
- for (bucket_num, bucket) in all_buckets:
- d0 = bucket.callRemote("get_metadata")
- d1 = bucket.callRemote("read")
- d2 = DeferredListShouldSucceed([d0, d1])
- def _got(res):
- shareid_s, sharedata = res
- shareid = bencode.bdecode(shareid_s)
- shares.append(sharedata)
- shareids.append(shareid)
- d2.addCallback(_got)
- dl.append(d2)
- d = DeferredListShouldSucceed(dl)
-
- d.addCallback(lambda res: self._decoder.decode(shares, shareids))
-
- def _write(decoded_shares):
- data = "".join(decoded_shares)
- self._target.open()
- hasher = sha.new(netstring("allmydata_v1_verifierid"))
- hasher.update(data)
- vid = hasher.digest()
- assert self._verifierid == vid, "%s != %s" % (idlib.b2a(self._verifierid), idlib.b2a(vid))
- self._target.write(data)
- d.addCallback(_write)
- def _done(res):
- self._target.close()
- return self._target.finish()
- def _fail(res):
- self._target.fail()
- return res
- d.addCallbacks(_done, _fail)
- return d
def netstring(s):
return "%d:%s," % (len(s), s)
from twisted.trial import unittest
from twisted.internet import defer
-from allmydata import encode_new
+from allmydata import encode_new, download
+from allmydata.uri import pack_uri
from cStringIO import StringIO
class MyEncoder(encode_new.Encoder):
class FakePeer:
def __init__(self):
self.blocks = {}
- self.blockhashes = None
- self.sharehashes = None
+ self.block_hashes = None
+ self.share_hashes = None
self.closed = False
def callRemote(self, methname, *args, **kwargs):
def put_block_hashes(self, blockhashes):
assert not self.closed
- assert self.blockhashes is None
- self.blockhashes = blockhashes
+ assert self.block_hashes is None
+ self.block_hashes = blockhashes
def put_share_hashes(self, sharehashes):
assert not self.closed
- assert self.sharehashes is None
- self.sharehashes = sharehashes
+ assert self.share_hashes is None
+ self.share_hashes = sharehashes
def close(self):
assert not self.closed
self.closed = True
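+
+ # read-back accessors: stand-ins for a real remote bucket, letting the
+ # downloader retrieve exactly what the encoder stored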
+ def get_block(self, blocknum):
+ assert isinstance(blocknum, int)
+ return self.blocks[blocknum]
+
+ def get_block_hashes(self):
+ return self.block_hashes
+ def get_share_hashes(self):
+ return self.share_hashes
+
+
class UpDown(unittest.TestCase):
def test_send(self):
e = encode_new.Encoder()
for i,peer in enumerate(all_shareholders):
self.failUnless(peer.closed)
self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
- #self.failUnlessEqual(len(peer.blockhashes), NUM_SEGMENTS)
+ #self.failUnlessEqual(len(peer.block_hashes), NUM_SEGMENTS)
# that isn't true: each peer gets a full tree, so it's more
# like 2n-1 but with rounding to a power of two
- for h in peer.blockhashes:
+ for h in peer.block_hashes:
self.failUnlessEqual(len(h), 32)
- #self.failUnlessEqual(len(peer.sharehashes), NUM_SHARES)
+ #self.failUnlessEqual(len(peer.share_hashes), NUM_SHARES)
# that isn't true: each peer only gets the chain they need
- for (hashnum, h) in peer.sharehashes:
+ for (hashnum, h) in peer.share_hashes:
self.failUnless(isinstance(hashnum, int))
self.failUnlessEqual(len(h), 32)
d.addCallback(_check)
return d
+
+ def test_send_and_recover(self):
+ e = encode_new.Encoder()
+ data = "happy happy joy joy" * 4
+ e.setup(StringIO(data))
+ NUM_SHARES = 100
+ assert e.num_shares == NUM_SHARES # else we'll be completely confused
+ e.segment_size = 25 # force use of multiple segments
+ e.setup_codec() # need to rebuild the codec for that change
+ NUM_SEGMENTS = 4
+ assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
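+ # arithmetic check: len(data) = 4*19 = 76, and 3*25 = 75 < 76 <= 100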
+ shareholders = {}
+ all_shareholders = []
+ for shnum in range(NUM_SHARES):
+ peer = FakePeer()
+ shareholders[shnum] = peer
+ all_shareholders.append(peer)
+ e.set_shareholders(shareholders)
+ d = e.start()
+ def _uploaded(roothash):
+ URI = pack_uri(e._codec.get_encoder_type(),
+ e._codec.get_serialized_params(),
+ "V" * 20,
+ roothash,
+ e.required_shares,
+ e.num_shares,
+ e.file_size,
+ e.segment_size)
+ client = None
+ target = download.Data()
+ fd = download.FileDownloader(client, URI, target)
+ fd._share_buckets = {}
+ for shnum in range(NUM_SHARES):
+ fd._share_buckets[shnum] = set([all_shareholders[shnum]])
+ fd._got_all_shareholders(None)
+ d2 = fd._download_all_segments()
+ d2.addCallback(fd._done)
+ return d2
+ d.addCallback(_uploaded)
+ def _downloaded(newdata):
+ self.failUnlessEqual(newdata, data)
+ d.addCallback(_downloaded)
+
+ return d