2 from zope.interface import implements
3 from twisted.trial import unittest
4 from twisted.internet import defer
5 from twisted.python.failure import Failure
6 from allmydata import encode, upload, download, hashtree, uri
7 from allmydata.util import hashutil
8 from allmydata.util.assertutil import _assert
9 from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
class LostPeerError(Exception):
    """Raised by FakeBucketWriterProxy in "lost" mode to simulate a peer
    that disconnects partway through an upload."""
def flip_bit(good): # flips the last bit
    """Return a copy of `good` with the low bit of its final byte inverted.

    Used to make data that is *almost* right, so hash checks must catch it.
    """
    prefix, last = good[:-1], good[-1]
    return prefix + chr(ord(last) ^ 0x01)
18 def log(self, *args, **kwargs):
class FakeBucketWriterProxy:
    """In-memory fake of a remote storage bucket.

    It records everything the encoder writes, and hands the data back to
    the downloader, optionally corrupting or withholding it according to
    `mode` (e.g. "good", "bad block", "bad sharehash", "missing sharehash",
    "lost", "bad uri_extension"). Every method runs its work inside
    defer.maybeDeferred so failures arrive as errbacks, just like real
    remote calls would.
    """
    implements(IStorageBucketWriter, IStorageBucketReader)
    # these are used for both reading and writing
    def __init__(self, mode="good"):
        self.mode = mode          # selects which misbehavior to simulate
        self.blocks = {}          # segmentnum -> block data
        self.plaintext_hashes = None
        self.crypttext_hashes = None
        self.block_hashes = None
        self.share_hashes = None
        self.closed = False

    def startIfNecessary(self):
        return defer.succeed(self)
    def start(self):
        return defer.succeed(self)

    def put_block(self, segmentnum, data):
        def _try():
            assert not self.closed
            assert segmentnum not in self.blocks
            # "lost" mode: accept the first segment, then vanish
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_try)

    def put_plaintext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert self.plaintext_hashes is None
            self.plaintext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_crypttext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert self.crypttext_hashes is None
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_block_hashes(self, blockhashes):
        def _try():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_try)

    def put_share_hashes(self, sharehashes):
        def _try():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_try)

    def put_uri_extension(self, uri_extension):
        def _try():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_try)

    def close(self):
        def _try():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_try)

    def abort(self):
        return defer.succeed(None)

    def get_block(self, blocknum):
        def _try():
            assert isinstance(blocknum, (int, long))
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        return defer.maybeDeferred(_try)

    def get_plaintext_hashes(self):
        def _try():
            hashes = self.plaintext_hashes[:]
            if self.mode == "bad plaintext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad plaintext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_crypttext_hashes(self):
        def _try():
            hashes = self.crypttext_hashes[:]
            if self.mode == "bad crypttext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad crypttext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_block_hashes(self):
        def _try():
            if self.mode == "bad blockhash":
                hashes = self.block_hashes[:]
                hashes[1] = flip_bit(hashes[1])
                return hashes
            return self.block_hashes
        return defer.maybeDeferred(_try)

    def get_share_hashes(self):
        def _try():
            if self.mode == "bad sharehash":
                hashes = self.share_hashes[:]
                hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
                return hashes
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        return defer.maybeDeferred(_try)

    def get_uri_extension(self):
        def _try():
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        return defer.maybeDeferred(_try)
def make_data(length):
    """Return `length` bytes of deterministic, repeating test data.

    Only up to 1900 bytes are available; asking for more is a test bug.
    """
    data = "happy happy joy joy" * 100
    assert length <= len(data)
    return data[:length]
class Encode(unittest.TestCase):
    """Exercise the encoder alone, feeding it FakeBucketWriterProxy
    shareholders and checking what each one receives."""

    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        """Encode `datalen` bytes into NUM_SHARES shares and verify the
        segment count and the per-peer block/share hash counts."""
        data = make_data(datalen)
        # force use of multiple segments
        options = {"max_segment_size": max_segment_size,
                   "needed_and_happy_and_total_shares": (25, 75, 100)}
        e = encode.Encoder(options)
        u = upload.Data(data)
        eu = upload.EncryptAnUploadable(u)
        d = e.set_encrypted_uploadable(eu)

        all_shareholders = []
        def _ready(res):
            k,happy,n = e.get_param("share_counts")
            _assert(n == NUM_SHARES) # else we'll be completely confused
            numsegs = e.get_param("num_segments")
            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
            segsize = e.get_param("segment_size")
            _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize,
                     NUM_SEGMENTS, segsize,
                     (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize)

            shareholders = {}
            for shnum in range(NUM_SHARES):
                peer = FakeBucketWriterProxy()
                shareholders[shnum] = peer
                all_shareholders.append(peer)
            e.set_shareholders(shareholders)
            return e.start()
        d.addCallback(_ready)

        def _check(res):
            (uri_extension_hash, required_shares, num_shares, file_size) = res
            self.failUnless(isinstance(uri_extension_hash, str))
            self.failUnlessEqual(len(uri_extension_hash), 32)
            for i,peer in enumerate(all_shareholders):
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)

        return d

    # a series of 3*3 tests to check out edge conditions. One axis is how the
    # plaintext is divided into segments: kn+(-1,0,1). Another way to express
    # that is that n%k == -1 or 0 or 1. For example, for 25-byte segments, we
    # might test 74 bytes, 75 bytes, and 76 bytes.

    # on the other axis is how many leaves in the block hash tree we wind up
    # with, relative to a power of 2, so 2^a+(-1,0,1). Each segment turns
    # into a single leaf. So we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long share
    # hash chain

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which get 15 blockhashes.

    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # 5 segments: 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)
class Roundtrip(unittest.TestCase):
    def send_and_recover(self, k_and_happy_and_n=(25,75,100),
                         AVAILABLE_SHARES=None,
                         datalen=76,
                         max_segment_size=25,
                         bucket_modes=None,
                         recover_mode="recover",
                         ):
        """Encode `datalen` bytes into fake buckets, then download them
        again and check the plaintext round-trips intact.

        `bucket_modes` maps shnum -> FakeBucketWriterProxy mode; unset
        shares default to "good". Returns a Deferred firing with the
        FileDownloader, so callers can inspect its failure counters.
        """
        if bucket_modes is None:
            # note: was a mutable default ({}); it is only ever read, but
            # the None-sentinel form is safer and call-compatible
            bucket_modes = {}
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = k_and_happy_and_n[2]
        data = make_data(datalen)
        d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                      max_segment_size, bucket_modes, data)
        # that fires with (uri_extension_hash, e, shareholders)
        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
        # that fires with newdata
        def _downloaded(res):
            (newdata, fd) = res
            self.failUnless(newdata == data)
            return fd
        d.addCallback(_downloaded)
        return d
286 def send(self, k_and_happy_and_n, AVAILABLE_SHARES, max_segment_size,
288 NUM_SHARES = k_and_happy_and_n[2]
289 if AVAILABLE_SHARES is None:
290 AVAILABLE_SHARES = NUM_SHARES
291 # force use of multiple segments
292 options = {"max_segment_size": max_segment_size,
293 "needed_and_happy_and_total_shares": k_and_happy_and_n}
294 e = encode.Encoder(options)
295 u = upload.Data(data)
296 eu = upload.EncryptAnUploadable(u)
297 d = e.set_encrypted_uploadable(eu)
301 k,happy,n = e.get_param("share_counts")
302 assert n == NUM_SHARES # else we'll be completely confused
304 for shnum in range(NUM_SHARES):
305 mode = bucket_modes.get(shnum, "good")
306 peer = FakeBucketWriterProxy(mode)
307 shareholders[shnum] = peer
308 e.set_shareholders(shareholders)
310 d.addCallback(_ready)
312 d1 = u.get_encryption_key()
313 d1.addCallback(lambda key: (res, key, shareholders))
318 def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
320 (uri_extension_hash, required_shares, num_shares, file_size) = res
322 if "corrupt_key" in recover_mode:
323 # we corrupt the key, so that the decrypted data is corrupted and
324 # will fail the plaintext hash check. Since we're manually
325 # attaching shareholders, the fact that the storage index is also
326 # corrupted doesn't matter.
329 u = uri.CHKFileURI(key=key,
330 uri_extension_hash=uri_extension_hash,
331 needed_shares=required_shares,
332 total_shares=num_shares,
336 client = FakeClient()
337 target = download.Data()
338 fd = download.FileDownloader(client, URI, target)
340 # we manually cycle the FileDownloader through a number of steps that
341 # would normally be sequenced by a Deferred chain in
342 # FileDownloader.start(), to give us more control over the process.
343 # In particular, by bypassing _get_all_shareholders, we skip
344 # permuted-peerlist selection.
345 for shnum, bucket in shareholders.items():
346 if shnum < AVAILABLE_SHARES and bucket.closed:
347 fd.add_share_bucket(shnum, bucket)
348 fd._got_all_shareholders(None)
350 # Make it possible to obtain uri_extension from the shareholders.
351 # Arrange for shareholders[0] to be the first, so we can selectively
352 # corrupt the data it returns.
353 fd._uri_extension_sources = shareholders.values()
354 fd._uri_extension_sources.remove(shareholders[0])
355 fd._uri_extension_sources.insert(0, shareholders[0])
357 d = defer.succeed(None)
359 # have the FileDownloader retrieve a copy of uri_extension itself
360 d.addCallback(fd._obtain_uri_extension)
362 if "corrupt_crypttext_hashes" in recover_mode:
363 # replace everybody's crypttext hash trees with a different one
364 # (computed over a different file), then modify our uri_extension
365 # to reflect the new crypttext hash tree root
366 def _corrupt_crypttext_hashes(uri_extension):
367 assert isinstance(uri_extension, dict)
368 assert 'crypttext_root_hash' in uri_extension
369 badhash = hashutil.tagged_hash("bogus", "data")
370 bad_crypttext_hashes = [badhash] * uri_extension['num_segments']
371 badtree = hashtree.HashTree(bad_crypttext_hashes)
372 for bucket in shareholders.values():
373 bucket.crypttext_hashes = list(badtree)
374 uri_extension['crypttext_root_hash'] = badtree[0]
376 d.addCallback(_corrupt_crypttext_hashes)
378 d.addCallback(fd._got_uri_extension)
380 # also have the FileDownloader ask for hash trees
381 d.addCallback(fd._get_hashtrees)
383 d.addCallback(fd._create_validated_buckets)
384 d.addCallback(fd._download_all_segments)
385 d.addCallback(fd._done)
391 def test_not_enough_shares(self):
392 d = self.send_and_recover((4,8,10), AVAILABLE_SHARES=2)
394 self.failUnless(isinstance(res, Failure))
395 self.failUnless(res.check(download.NotEnoughPeersError))
399 def test_one_share_per_peer(self):
400 return self.send_and_recover()
403 return self.send_and_recover(datalen=74)
405 return self.send_and_recover(datalen=75)
407 return self.send_and_recover(datalen=51)
410 return self.send_and_recover(datalen=99)
412 return self.send_and_recover(datalen=100)
414 return self.send_and_recover(datalen=76)
417 return self.send_and_recover(datalen=124)
419 return self.send_and_recover(datalen=125)
421 return self.send_and_recover(datalen=101)
423 # the following tests all use 4-out-of-10 encoding
425 def test_bad_blocks(self):
426 # the first 6 servers have bad blocks, which will be caught by the
428 modemap = dict([(i, "bad block")
431 for i in range(6, 10)])
432 return self.send_and_recover((4,8,10), bucket_modes=modemap)
434 def test_bad_blocks_failure(self):
435 # the first 7 servers have bad blocks, which will be caught by the
436 # blockhashes, and the download will fail
437 modemap = dict([(i, "bad block")
440 for i in range(7, 10)])
441 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
443 self.failUnless(isinstance(res, Failure))
444 self.failUnless(res.check(download.NotEnoughPeersError))
448 def test_bad_blockhashes(self):
449 # the first 6 servers have bad block hashes, so the blockhash tree
451 modemap = dict([(i, "bad blockhash")
454 for i in range(6, 10)])
455 return self.send_and_recover((4,8,10), bucket_modes=modemap)
457 def test_bad_blockhashes_failure(self):
458 # the first 7 servers have bad block hashes, so the blockhash tree
459 # will not validate, and the download will fail
460 modemap = dict([(i, "bad blockhash")
463 for i in range(7, 10)])
464 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
466 self.failUnless(isinstance(res, Failure))
467 self.failUnless(res.check(download.NotEnoughPeersError))
471 def test_bad_sharehashes(self):
472 # the first 6 servers have bad block hashes, so the sharehash tree
474 modemap = dict([(i, "bad sharehash")
477 for i in range(6, 10)])
478 return self.send_and_recover((4,8,10), bucket_modes=modemap)
480 def assertFetchFailureIn(self, fd, where):
481 expected = {"uri_extension": 0,
482 "plaintext_hashroot": 0,
483 "plaintext_hashtree": 0,
484 "crypttext_hashroot": 0,
485 "crypttext_hashtree": 0,
487 if where is not None:
489 self.failUnlessEqual(fd._fetch_failures, expected)
492 # just to make sure the test harness works when we aren't
493 # intentionally causing failures
494 modemap = dict([(i, "good") for i in range(0, 10)])
495 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
496 d.addCallback(self.assertFetchFailureIn, None)
499 def test_bad_uri_extension(self):
500 # the first server has a bad uri_extension block, so we will fail
501 # over to a different server.
502 modemap = dict([(i, "bad uri_extension") for i in range(1)] +
503 [(i, "good") for i in range(1, 10)])
504 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
505 d.addCallback(self.assertFetchFailureIn, "uri_extension")
508 def test_bad_plaintext_hashroot(self):
509 # the first server has a bad plaintext hashroot, so we will fail over
510 # to a different server.
511 modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
512 [(i, "good") for i in range(1, 10)])
513 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
514 d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
517 def test_bad_crypttext_hashroot(self):
518 # the first server has a bad crypttext hashroot, so we will fail
519 # over to a different server.
520 modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
521 [(i, "good") for i in range(1, 10)])
522 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
523 d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
526 def test_bad_plaintext_hashes(self):
527 # the first server has a bad plaintext hash block, so we will fail
528 # over to a different server.
529 modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
530 [(i, "good") for i in range(1, 10)])
531 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
532 d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
535 def test_bad_crypttext_hashes(self):
536 # the first server has a bad crypttext hash block, so we will fail
537 # over to a different server.
538 modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
539 [(i, "good") for i in range(1, 10)])
540 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
541 d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
544 def test_bad_crypttext_hashes_failure(self):
545 # to test that the crypttext merkle tree is really being applied, we
546 # sneak into the download process and corrupt two things: we replace
547 # everybody's crypttext hashtree with a bad version (computed over
548 # bogus data), and we modify the supposedly-validated uri_extension
549 # block to match the new crypttext hashtree root. The download
550 # process should notice that the crypttext coming out of FEC doesn't
551 # match the tree, and fail.
553 modemap = dict([(i, "good") for i in range(0, 10)])
554 d = self.send_and_recover((4,8,10), bucket_modes=modemap,
555 recover_mode=("corrupt_crypttext_hashes"))
557 self.failUnless(isinstance(res, Failure))
558 self.failUnless(res.check(hashtree.BadHashError), res)
563 def test_bad_plaintext(self):
564 # faking a decryption failure is easier: just corrupt the key
565 modemap = dict([(i, "good") for i in range(0, 10)])
566 d = self.send_and_recover((4,8,10), bucket_modes=modemap,
567 recover_mode=("corrupt_key"))
569 self.failUnless(isinstance(res, Failure))
570 self.failUnless(res.check(hashtree.BadHashError), res)
574 def test_bad_sharehashes_failure(self):
575 # the first 7 servers have bad block hashes, so the sharehash tree
576 # will not validate, and the download will fail
577 modemap = dict([(i, "bad sharehash")
580 for i in range(7, 10)])
581 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
583 self.failUnless(isinstance(res, Failure))
584 self.failUnless(res.check(download.NotEnoughPeersError))
588 def test_missing_sharehashes(self):
589 # the first 6 servers are missing their sharehashes, so the
590 # sharehash tree will not validate
591 modemap = dict([(i, "missing sharehash")
594 for i in range(6, 10)])
595 return self.send_and_recover((4,8,10), bucket_modes=modemap)
597 def test_missing_sharehashes_failure(self):
598 # the first 7 servers are missing their sharehashes, so the
599 # sharehash tree will not validate, and the download will fail
600 modemap = dict([(i, "missing sharehash")
603 for i in range(7, 10)])
604 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
606 self.failUnless(isinstance(res, Failure))
607 self.failUnless(res.check(download.NotEnoughPeersError))
611 def test_lost_one_shareholder(self):
612 # we have enough shareholders when we start, but one segment in we
613 # lose one of them. The upload should still succeed, as long as we
614 # still have 'shares_of_happiness' peers left.
615 modemap = dict([(i, "good") for i in range(9)] +
616 [(i, "lost") for i in range(9, 10)])
617 return self.send_and_recover((4,8,10), bucket_modes=modemap)
619 def test_lost_many_shareholders(self):
620 # we have enough shareholders when we start, but one segment in we
621 # lose all but one of them. The upload should fail.
622 modemap = dict([(i, "good") for i in range(1)] +
623 [(i, "lost") for i in range(1, 10)])
624 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
626 self.failUnless(isinstance(res, Failure))
627 self.failUnless(res.check(encode.NotEnoughPeersError), res)
631 def test_lost_all_shareholders(self):
632 # we have enough shareholders when we start, but one segment in we
633 # lose all of them. The upload should fail.
634 modemap = dict([(i, "lost") for i in range(10)])
635 d = self.send_and_recover((4,8,10), bucket_modes=modemap)
637 self.failUnless(isinstance(res, Failure))
638 self.failUnless(res.check(encode.NotEnoughPeersError))