from zope.interface import implements
from twisted.trial import unittest
from twisted.internet import defer
from twisted.python.failure import Failure
from foolscap import eventual
from allmydata import encode, download, hashtree
from allmydata.util import hashutil
from allmydata.uri import pack_uri
from allmydata.Crypto.Cipher import AES
from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
from cStringIO import StringIO

class FakeClient:
    def __init__(self, mode="good"):
        self.ss = FakeStorageServer(mode)

    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        return defer.maybeDeferred(_call)

    def get_service(self, sname):
        assert sname == "storageserver"
        return self.ss

class FakeStorageServer:
    def __init__(self, mode):
        self.mode = mode

    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = eventual.fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def allocate_buckets(self, crypttext_hash, sharenums, sharesize,
                         blocksize, canary):
        if self.mode == "full":
            return (set(), {},)
        elif self.mode == "already got them":
            return (set(sharenums), {},)
        else:
            return (set(),
                    dict([(shnum, FakeBucketWriter(),)
                          for shnum in sharenums]),)

class LostPeerError(Exception):
    pass

def flip_bit(good): # flips the last bit
    return good[:-1] + chr(ord(good[-1]) ^ 0x01)
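# e.g. flip_bit("\x00\x00") == "\x00\x01". Corrupting a single bit is enough
# to make any of the 32-byte hashes checked below fail to verify.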
class FakeBucketWriter:
    implements(IStorageBucketWriter, IStorageBucketReader)
    # these are used for both reading and writing
    def __init__(self, mode="good"):
        self.mode = mode
        self.blocks = {}
        self.plaintext_hashes = None
        self.crypttext_hashes = None
        self.block_hashes = None
        self.share_hashes = None
        self.uri_extension = None
        self.closed = False

    def startIfNecessary(self):
        return defer.succeed(self)
    def start(self):
        return defer.succeed(self)

    def put_block(self, segmentnum, data):
        def _try():
            assert not self.closed
            assert segmentnum not in self.blocks
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_try)

    def put_plaintext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert self.plaintext_hashes is None
            self.plaintext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_crypttext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert self.crypttext_hashes is None
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_block_hashes(self, blockhashes):
        def _try():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_try)

    def put_share_hashes(self, sharehashes):
        def _try():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_try)

    def put_uri_extension(self, uri_extension):
        def _try():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_try)

    def close(self):
        def _try():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_try)

    def get_block(self, blocknum):
        def _try():
            assert isinstance(blocknum, (int, long))
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        return defer.maybeDeferred(_try)

    def get_plaintext_hashes(self):
        def _try():
            hashes = self.plaintext_hashes[:]
            if self.mode == "bad plaintext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad plaintext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_crypttext_hashes(self):
        def _try():
            hashes = self.crypttext_hashes[:]
            if self.mode == "bad crypttext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad crypttext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_block_hashes(self):
        def _try():
            if self.mode == "bad blockhash":
                hashes = self.block_hashes[:]
                hashes[1] = flip_bit(hashes[1])
                return hashes
            return self.block_hashes
        return defer.maybeDeferred(_try)

    def get_share_hashes(self):
        def _try():
            if self.mode == "bad sharehash":
                hashes = self.share_hashes[:]
                hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
                return hashes
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        return defer.maybeDeferred(_try)

    def get_uri_extension(self):
        def _try():
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        return defer.maybeDeferred(_try)
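# Illustrative usage of FakeBucketWriter (a sketch for explanation only, not
# part of the test suite): writes and reads are all in-memory and fire via
# Deferreds, so a round trip returns the original block unless one of the
# failure modes above is configured.
#
#   bw = FakeBucketWriter()                      # mode defaults to "good"
#   d = bw.put_block(0, "some data")
#   d.addCallback(lambda res: bw.close())
#   d.addCallback(lambda res: bw.get_block(0))   # fires with "some data"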
def make_data(length):
    data = "happy happy joy joy" * 100
    assert length <= len(data)
    return data[:length]

class Encode(unittest.TestCase):

    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        data = make_data(datalen)
        # force use of multiple segments
        options = {"max_segment_size": max_segment_size}
        e = encode.Encoder(options)
        nonkey = "\x00" * 16
        e.setup(StringIO(data), nonkey)
        assert e.num_shares == NUM_SHARES # else we'll be completely confused
        e.setup_codec() # need to rebuild the codec for that change
        assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
        shareholders = {}
        all_shareholders = []
        for shnum in range(NUM_SHARES):
            peer = FakeBucketWriter()
            shareholders[shnum] = peer
            all_shareholders.append(peer)
        e.set_shareholders(shareholders)
        d = e.start()
        def _check(roothash):
            self.failUnless(isinstance(roothash, str))
            self.failUnlessEqual(len(roothash), 32)
            for i,peer in enumerate(all_shareholders):
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)
        return d
    # a series of 3*3 tests to check out edge conditions. One axis is how the
    # plaintext is divided into segments: kn+(-1,0,1). Another way to express
    # that is that n%k == -1 or 0 or 1. For example, for 25-byte segments, we
    # might test 74 bytes, 75 bytes, and 76 bytes.

    # on the other axis is how many leaves in the block hash tree we wind up
    # with, relative to a power of 2, so 2^a+(-1,0,1). Each segment turns
    # into a single leaf. So we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long share
    # hash chain

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which get 15 blockhashes.
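    # An illustrative sketch of where those numbers come from (for
    # explanation only; these helpers are not used by the tests): a complete
    # binary hash tree over N leaves, with N rounded up to 2^a, holds
    # 2*2^a - 1 hashes, and a share's hash chain has one entry per tree
    # level, i.e. a+1 entries for 2^a leaves.
    #
    #   def tree_hash_count(num_leaves):
    #       p = 1
    #       while p < num_leaves:
    #           p *= 2
    #       return 2*p - 1    # 4 leaves -> 7 hashes, 8 leaves -> 15 hashes
    #
    #   def chain_length(num_leaves):
    #       p, levels = 1, 1
    #       while p < num_leaves:
    #           p *= 2
    #           levels += 1
    #       return levels     # 128 leaves -> 8 entries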
    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # encode a 101 byte file (in 5 segments: 25,25,25,25,1) to 100 shares
        return self.do_encode(25, 101, 100, 5, 15, 8)

class Roundtrip(unittest.TestCase):
    def send_and_recover(self, k_and_happy_and_n=(25,75,100),
                         AVAILABLE_SHARES=None,
                         datalen=76,
                         max_segment_size=25,
                         bucket_modes={},
                         recover_mode="recover",
                         ):
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = k_and_happy_and_n[2]
        data = make_data(datalen)
        d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                      max_segment_size, bucket_modes, data)
        # that fires with (uri_extension_hash, e, shareholders)
        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
        # that fires with newdata
        def _downloaded((newdata, fd)):
            self.failUnless(newdata == data)
            return fd
        d.addCallback(_downloaded)
        return d
    def send(self, k_and_happy_and_n, AVAILABLE_SHARES, max_segment_size,
             bucket_modes, data):
        NUM_SHARES = k_and_happy_and_n[2]
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = NUM_SHARES
        # force use of multiple segments
        options = {"max_segment_size": max_segment_size,
                   "needed_and_happy_and_total_shares": k_and_happy_and_n}
        e = encode.Encoder(options)
        nonkey = "\x00" * 16
        e.setup(StringIO(data), nonkey)
        assert e.num_shares == NUM_SHARES # else we'll be completely confused
        e.setup_codec() # need to rebuild the codec for that change

        shareholders = {}
        for shnum in range(NUM_SHARES):
            mode = bucket_modes.get(shnum, "good")
            peer = FakeBucketWriter(mode)
            shareholders[shnum] = peer
        e.set_shareholders(shareholders)
        plaintext_hasher = hashutil.plaintext_hasher()
        plaintext_hasher.update(data)
        cryptor = AES.new(key=nonkey, mode=AES.MODE_CTR,
                          counterstart="\x00"*16)
        crypttext_hasher = hashutil.crypttext_hasher()
        crypttext_hasher.update(cryptor.encrypt(data))

        e.set_uri_extension_data({'crypttext_hash': crypttext_hasher.digest(),
                                  'plaintext_hash': plaintext_hasher.digest(),
                                  })
        d = e.start()
        def _sent(uri_extension_hash):
            return (uri_extension_hash, e, shareholders)
        d.addCallback(_sent)
        return d
    def recover(self, (uri_extension_hash, e, shareholders), AVAILABLE_SHARES,
                recover_mode):
        key = e.key
        if "corrupt_key" in recover_mode:
            key = flip_bit(key)

        URI = pack_uri(storage_index="S" * 32,
                       key=key,
                       uri_extension_hash=uri_extension_hash,
                       needed_shares=e.required_shares,
                       total_shares=e.num_shares,
                       size=e.file_size)
        client = None
        target = download.Data()
        fd = download.FileDownloader(client, URI, target)

        # we manually cycle the FileDownloader through a number of steps that
        # would normally be sequenced by a Deferred chain in
        # FileDownloader.start(), to give us more control over the process.
        # In particular, by bypassing _get_all_shareholders, we skip
        # permuted-peerlist selection.
        for shnum, bucket in shareholders.items():
            if shnum < AVAILABLE_SHARES and bucket.closed:
                fd.add_share_bucket(shnum, bucket)
        fd._got_all_shareholders(None)

        # Make it possible to obtain uri_extension from the shareholders.
        # Arrange for shareholders[0] to be the first, so we can selectively
        # corrupt the data it returns.
        fd._uri_extension_sources = shareholders.values()
        fd._uri_extension_sources.remove(shareholders[0])
        fd._uri_extension_sources.insert(0, shareholders[0])

        d = defer.succeed(None)

        # have the FileDownloader retrieve a copy of uri_extension itself
        d.addCallback(fd._obtain_uri_extension)

        if "corrupt_crypttext_hashes" in recover_mode:
            # replace everybody's crypttext hash trees with a different one
            # (computed over a different file), then modify our uri_extension
            # to reflect the new crypttext hash tree root
            def _corrupt_crypttext_hashes(uri_extension):
                assert isinstance(uri_extension, dict)
                assert 'crypttext_root_hash' in uri_extension
                badhash = hashutil.tagged_hash("bogus", "data")
                bad_crypttext_hashes = [badhash] * uri_extension['num_segments']
                badtree = hashtree.HashTree(bad_crypttext_hashes)
                for bucket in shareholders.values():
                    bucket.crypttext_hashes = list(badtree)
                uri_extension['crypttext_root_hash'] = badtree[0]
                return uri_extension
            d.addCallback(_corrupt_crypttext_hashes)

        d.addCallback(fd._got_uri_extension)

        # also have the FileDownloader ask for hash trees
        d.addCallback(fd._get_hashtrees)

        d.addCallback(fd._create_validated_buckets)
        d.addCallback(fd._download_all_segments)
        d.addCallback(fd._done)
        def _done(newdata):
            return (newdata, fd)
        d.addCallback(_done)
        return d
    def test_not_enough_shares(self):
        d = self.send_and_recover((4,8,10), AVAILABLE_SHARES=2)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_one_share_per_peer(self):
        return self.send_and_recover()

    def test_74(self):
        return self.send_and_recover(datalen=74)
    def test_75(self):
        return self.send_and_recover(datalen=75)
    def test_51(self):
        return self.send_and_recover(datalen=51)

    def test_99(self):
        return self.send_and_recover(datalen=99)
    def test_100(self):
        return self.send_and_recover(datalen=100)
    def test_76(self):
        return self.send_and_recover(datalen=76)

    def test_124(self):
        return self.send_and_recover(datalen=124)
    def test_125(self):
        return self.send_and_recover(datalen=125)
    def test_101(self):
        return self.send_and_recover(datalen=101)
    # the following tests all use 4-out-of-10 encoding

    def test_bad_blocks(self):
        # the first 6 servers have bad blocks, which will be caught by the
        # blockhashes
        modemap = dict([(i, "bad block")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)
    def test_bad_blocks_failure(self):
        # the first 7 servers have bad blocks, which will be caught by the
        # blockhashes, and the download will fail
        modemap = dict([(i, "bad block")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_bad_blockhashes(self):
        # the first 6 servers have bad block hashes, so the blockhash tree
        # will not validate
        modemap = dict([(i, "bad blockhash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)
    def test_bad_blockhashes_failure(self):
        # the first 7 servers have bad block hashes, so the blockhash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad blockhash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_bad_sharehashes(self):
        # the first 6 servers have bad share hashes, so the sharehash tree
        # will not validate
        modemap = dict([(i, "bad sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)
    def assertFetchFailureIn(self, fd, where):
        expected = {"uri_extension": 0,
                    "plaintext_hashroot": 0,
                    "plaintext_hashtree": 0,
                    "crypttext_hashroot": 0,
                    "crypttext_hashtree": 0,
                    }
        if where is not None:
            expected[where] += 1
        self.failUnlessEqual(fd._fetch_failures, expected)
    def test_good(self):
        # just to make sure the test harness works when we aren't
        # intentionally causing failures
        modemap = dict([(i, "good") for i in range(0, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, None)
        return d
    def test_bad_uri_extension(self):
        # the first server has a bad uri_extension block, so we will fail
        # over to a different server.
        modemap = dict([(i, "bad uri_extension") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "uri_extension")
        return d
    def test_bad_plaintext_hashroot(self):
        # the first server has a bad plaintext hashroot, so we will fail over
        # to a different server.
        modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
        return d
    def test_bad_crypttext_hashroot(self):
        # the first server has a bad crypttext hashroot, so we will fail
        # over to a different server.
        modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
        return d
    def test_bad_plaintext_hashes(self):
        # the first server has a bad plaintext hash block, so we will fail
        # over to a different server.
        modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
        return d
    def test_bad_crypttext_hashes(self):
        # the first server has a bad crypttext hash block, so we will fail
        # over to a different server.
        modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
        return d
    def test_bad_crypttext_hashes_failure(self):
        # to test that the crypttext merkle tree is really being applied, we
        # sneak into the download process and corrupt two things: we replace
        # everybody's crypttext hashtree with a bad version (computed over
        # bogus data), and we modify the supposedly-validated uri_extension
        # block to match the new crypttext hashtree root. The download
        # process should notice that the crypttext coming out of FEC doesn't
        # match the tree, and fail.

        modemap = dict([(i, "good") for i in range(0, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
                                  recover_mode=("corrupt_crypttext_hashes"))
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(hashtree.BadHashError), res)
        d.addBoth(_done)
        return d
    def test_bad_plaintext(self):
        # faking a decryption failure is easier: just corrupt the key
        modemap = dict([(i, "good") for i in range(0, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
                                  recover_mode=("corrupt_key"))
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(hashtree.BadHashError))
        d.addBoth(_done)
        return d
    def test_bad_sharehashes_failure(self):
        # the first 7 servers have bad share hashes, so the sharehash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_missing_sharehashes(self):
        # the first 6 servers are missing their sharehashes, so the
        # sharehash tree will not validate
        modemap = dict([(i, "missing sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)
    def test_missing_sharehashes_failure(self):
        # the first 7 servers are missing their sharehashes, so the
        # sharehash tree will not validate, and the download will fail
        modemap = dict([(i, "missing sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_lost_one_shareholder(self):
        # we have enough shareholders when we start, but one segment in we
        # lose one of them. The upload should still succeed, as long as we
        # still have 'shares_of_happiness' peers left.
        modemap = dict([(i, "good") for i in range(9)] +
                       [(i, "lost") for i in range(9, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)
    def test_lost_many_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all but one of them. The upload should fail.
        modemap = dict([(i, "good") for i in range(1)] +
                       [(i, "lost") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError))
        d.addBoth(_done)
        return d
    def test_lost_all_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all of them. The upload should fail.
        modemap = dict([(i, "lost") for i in range(10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError))
        d.addBoth(_done)
        return d
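
# To run this module by itself, the usual Twisted Trial invocation should
# work (assumption: the module lives at allmydata/test/test_encode.py in
# the source tree):
#   trial allmydata.test.test_encode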