3 from twisted.trial import unittest
4 from twisted.internet import defer
5 from foolscap import eventual
6 from allmydata import encode, download
7 from allmydata.uri import pack_uri
8 from cStringIO import StringIO
def __init__(self, mode="good"):
    """Fake client exposing one in-process storage server.

    *mode* is forwarded to FakeStorageServer and selects its canned
    allocate_buckets behavior ("good", "full", "already got them").
    """
    # Single fake server handed out by get_service("storageserver").
    self.ss = FakeStorageServer(mode)
# Fake remoting shim: dispatch a "remote" call to the local method of the
# same name, wrapped in a Deferred via maybeDeferred so callers see the
# asynchronous interface of a real remote reference.
# NOTE(review): this listing is missing a line here (orig line 15) --
# presumably the `def _call():` header for the closure used below; confirm
# against the complete file.
14 def callRemote(self, methname, *args, **kwargs):
16 meth = getattr(self, methname)
17 return meth(*args, **kwargs)
18 return defer.maybeDeferred(_call)
# Only the "storageserver" service exists on this fake client.
20 def get_service(self, sname):
21 assert sname == "storageserver"
# NOTE(review): the next line (orig 22) is absent from this listing --
# presumably `return self.ss`, handing back the FakeStorageServer created
# in __init__; confirm against the complete file.
# In-process stand-in for a remote storage server.  `mode` selects canned
# allocate_buckets behavior so tests can simulate a full server or one that
# already holds the requested shares.
# NOTE(review): this listing has gaps (orig lines 26, 28, 33, 36, 39 are
# absent) -- presumably the `self.mode = mode` assignment, the
# `def _call():` header, callRemote's `return d`, the "full" branch's
# return value, and the `else:` before the default branch.  Confirm
# against the complete file.
24 class FakeStorageServer:
25 def __init__(self, mode):
# Fake remoting: look up the named method locally and fire it on a later
# reactor turn via foolscap's eventual-send, mimicking the asynchrony of
# a real remote call.
27 def callRemote(self, methname, *args, **kwargs):
29 meth = getattr(self, methname)
30 return meth(*args, **kwargs)
31 d = eventual.fireEventually()
32 d.addCallback(lambda res: _call())
# Returns (set of already-present sharenums, dict shnum->FakeBucketWriter):
#   "full":              presumably allocates nothing (return line missing)
#   "already got them":  claims every requested share already exists
#   default:             a fresh FakeBucketWriter for each requested share
# NOTE(review): parameter `shareize` looks like a typo for `sharesize`; it
# is unused in the visible body, but confirm callers before renaming.
34 def allocate_buckets(self, verifierid, sharenums, shareize, blocksize, canary):
35 if self.mode == "full":
37 elif self.mode == "already got them":
38 return (set(sharenums), {},)
40 return (set(), dict([(shnum, FakeBucketWriter(),) for shnum in sharenums]),)
# Records everything a real remote bucket writer would receive (blocks and
# hash trees), so tests can inspect the uploaded data afterwards.
# NOTE(review): the __init__ header and its first lines (orig 43-44, 47-48)
# are missing from this listing -- presumably `self.closed = False` and
# `self.blocks = {}` are initialized there, since later methods read both;
# confirm against the complete file.
42 class FakeBucketWriter:
45 self.block_hashes = None
46 self.share_hashes = None
# Fake remoting shim: dispatch the named "remote" method locally, wrapped
# in a Deferred.
# NOTE(review): the `def _call():` header (orig line 50) is absent from
# this listing; confirm against the complete file.
49 def callRemote(self, methname, *args, **kwargs):
51 meth = getattr(self, methname)
52 return meth(*args, **kwargs)
53 return defer.maybeDeferred(_call)
55 def put_block(self, segmentnum, data):
56 assert not self.closed
57 assert segmentnum not in self.blocks
58 self.blocks[segmentnum] = data
60 def put_block_hashes(self, blockhashes):
61 assert not self.closed
62 assert self.block_hashes is None
63 self.block_hashes = blockhashes
65 def put_share_hashes(self, sharehashes):
66 assert not self.closed
67 assert self.share_hashes is None
68 self.share_hashes = sharehashes
# NOTE(review): orphaned fragment -- the surrounding method header (orig
# line 70, presumably `def close(self):`) and the rest of its body (orig
# 72-74, presumably setting `self.closed = True`) are missing from this
# listing.  As shown, it guards against closing the bucket twice.
71 assert not self.closed
75 def get_block(self, blocknum):
76 assert isinstance(blocknum, int)
77 return self.blocks[blocknum]
79 def get_block_hashes(self):
80 return self.block_hashes
81 def get_share_hashes(self):
82 return self.share_hashes
# Unit test for the encoder: force a small segment_size so the data spans
# multiple segments, encode into FakeBucketWriter shareholders, and verify
# each shareholder ends up closed with one block per segment and 32-byte
# hashes in both hash trees.
# NOTE(review): this listing is missing several lines (orig 86-87, 90, 94,
# 96-97, 103, 121+) -- presumably the test-method header, the
# encode.Encoder() construction, the NUM_SHARES/NUM_SEGMENTS constants,
# the shareholders dict initialization, and the `d = e.start()` call that
# produces `d`.  Confirm against the complete file.
85 class Encode(unittest.TestCase):
88 data = "happy happy joy joy" * 4
89 e.setup(StringIO(data))
91 assert e.num_shares == NUM_SHARES # else we'll be completely confused
92 e.segment_size = 25 # force use of multiple segments
93 e.setup_codec() # need to rebuild the codec for that change
95 assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
# One distinct FakeBucketWriter per share; remember them for later checks.
98 for shnum in range(NUM_SHARES):
99 peer = FakeBucketWriter()
100 shareholders[shnum] = peer
101 all_shareholders.append(peer)
102 e.set_shareholders(shareholders)
# Runs when encoding completes; `roothash` is the 32-byte share-tree root.
104 def _check(roothash):
105 self.failUnless(isinstance(roothash, str))
106 self.failUnlessEqual(len(roothash), 32)
107 for i,peer in enumerate(all_shareholders):
108 self.failUnless(peer.closed)
109 self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
110 #self.failUnlessEqual(len(peer.block_hashes), NUM_SEGMENTS)
111 # that isn't true: each peer gets a full tree, so it's more
112 # like 2n-1 but with rounding to a power of two
113 for h in peer.block_hashes:
114 self.failUnlessEqual(len(h), 32)
115 #self.failUnlessEqual(len(peer.share_hashes), NUM_SHARES)
116 # that isn't true: each peer only gets the chain they need
117 for (hashnum, h) in peer.share_hashes:
118 self.failUnless(isinstance(hashnum, int))
119 self.failUnlessEqual(len(h), 32)
120 d.addCallback(_check)
# End-to-end test: encode `data` into NUM_SHARES shares spread across
# NUM_PEERS FakeBucketWriters, then run the downloader against the same
# in-memory buckets and check the recovered plaintext matches.
# NOTE(review): this listing is missing lines (orig 126, 129, 133, 135,
# 137, 145, 150-156, 165, 170-172) -- presumably the Encoder construction,
# the all_peers/shareholders initialization, the remaining pack_uri
# arguments, the FakeClient construction, `return d2` from _uploaded, and
# the final `return d`.  Confirm against the complete file.
124 class Roundtrip(unittest.TestCase):
125 def send_and_recover(self, NUM_SHARES, NUM_PEERS, NUM_SEGMENTS=4):
127 data = "happy happy joy joy" * 4
128 e.setup(StringIO(data))
130 assert e.num_shares == NUM_SHARES # else we'll be completely confused
131 e.segment_size = 25 # force use of multiple segments
132 e.setup_codec() # need to rebuild the codec for that change
134 assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
136 all_shareholders = []
# Round-robin shares onto peers, so peers may hold multiple shares when
# NUM_SHARES > NUM_PEERS.
138 for i in range(NUM_PEERS):
139 all_peers.append(FakeBucketWriter())
140 for shnum in range(NUM_SHARES):
141 peer = all_peers[shnum % NUM_PEERS]
142 shareholders[shnum] = peer
143 all_shareholders.append(peer)
144 e.set_shareholders(shareholders)
# After upload: build a URI from the encoder's parameters, then drive the
# downloader by hand against the very same fake buckets.
146 def _uploaded(roothash):
147 URI = pack_uri(e._codec.get_encoder_type(),
148 e._codec.get_serialized_params(),
149 e._tail_codec.get_serialized_params(),
157 target = download.Data()
158 fd = download.FileDownloader(client, URI, target)
159 for shnum in range(NUM_SHARES):
160 bucket = all_shareholders[shnum]
161 fd.add_share_bucket(shnum, bucket)
# Private downloader entry points are called directly because the
# buckets were injected above rather than located over the network.
162 fd._got_all_shareholders(None)
163 d2 = fd._download_all_segments(None)
164 d2.addCallback(fd._done)
166 d.addCallback(_uploaded)
# Final check: the downloaded plaintext equals the original input.
167 def _downloaded(newdata):
168 self.failUnless(newdata == data)
169 d.addCallback(_downloaded)
def test_one_share_per_peer(self):
    """Round-trip where the 100 shares land on 100 distinct peers."""
    return self.send_and_recover(100, 100)