from zope.interface import implements
from twisted.trial import unittest
from twisted.internet import defer
from twisted.python.failure import Failure
from foolscap import eventual
from allmydata import encode, upload, download, hashtree, uri
from allmydata.util import hashutil
from allmydata.util.assertutil import _assert
from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader

class LostPeerError(Exception):
    pass

def flip_bit(good): # flips the last bit
    return good[:-1] + chr(ord(good[-1]) ^ 0x01)
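# For example, flip_bit("\x00") == "\x01": corrupting a single bit of a
# block, hash, or key is enough to make the corresponding integrity check
# fail, which is what the damage-injecting tests below rely on.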

class FakeClient:
    def log(self, *args, **kwargs):
        pass

class FakeBucketWriterProxy:
    implements(IStorageBucketWriter, IStorageBucketReader)
    # these are used for both reading and writing
    def __init__(self, mode="good"):
        self.mode = mode
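        # Recognized modes (a summary of the cases handled below): "good",
        # "lost-early", "lost", "bad block", "bad plaintext hashroot",
        # "bad plaintext hash", "bad crypttext hashroot",
        # "bad crypttext hash", "bad blockhash", "bad sharehash",
        # "missing sharehash", and "bad uri_extension".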
        self.blocks = {}
        self.plaintext_hashes = []
        self.crypttext_hashes = []
        self.block_hashes = None
        self.share_hashes = None
        self.closed = False

    def get_peerid(self):
        return "peerid"

    def startIfNecessary(self):
        return defer.succeed(self)
    def start(self):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return eventual.fireEventually(f)
        return defer.succeed(self)

    def put_block(self, segmentnum, data):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return eventual.fireEventually(f)
        def _try():
            assert not self.closed
            assert segmentnum not in self.blocks
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_try)

    def put_plaintext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert not self.plaintext_hashes
            self.plaintext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_crypttext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert not self.crypttext_hashes
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_block_hashes(self, blockhashes):
        def _try():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_try)

    def put_share_hashes(self, sharehashes):
        def _try():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_try)

    def put_uri_extension(self, uri_extension):
        def _try():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_try)

    def close(self):
        def _try():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_try)

    def abort(self):
        return defer.succeed(None)

    def get_block(self, blocknum):
        def _try():
            assert isinstance(blocknum, (int, long))
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        return defer.maybeDeferred(_try)

    def get_plaintext_hashes(self):
        def _try():
            hashes = self.plaintext_hashes[:]
            if self.mode == "bad plaintext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad plaintext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_crypttext_hashes(self):
        def _try():
            hashes = self.crypttext_hashes[:]
            if self.mode == "bad crypttext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad crypttext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        return defer.maybeDeferred(_try)

    def get_block_hashes(self):
        def _try():
            if self.mode == "bad blockhash":
                hashes = self.block_hashes[:]
                hashes[1] = flip_bit(hashes[1])
                return hashes
            return self.block_hashes
        return defer.maybeDeferred(_try)

    def get_share_hashes(self):
        def _try():
            if self.mode == "bad sharehash":
                hashes = self.share_hashes[:]
                hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
                return hashes
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        return defer.maybeDeferred(_try)

    def get_uri_extension(self):
        def _try():
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        return defer.maybeDeferred(_try)


def make_data(length):
    data = "happy happy joy joy" * 100
    assert length <= len(data)
    return data[:length]

class Encode(unittest.TestCase):

    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        data = make_data(datalen)
        e = encode.Encoder()
        u = upload.Data(data, convergence="some convergence string")
        # force use of multiple segments
        u.max_segment_size = max_segment_size
        u.encoding_param_k = 25
        u.encoding_param_happy = 75
        u.encoding_param_n = 100
        eu = upload.EncryptAnUploadable(u)
        d = e.set_encrypted_uploadable(eu)

        all_shareholders = []
        def _ready(res):
            k, happy, n = e.get_param("share_counts")
            _assert(n == NUM_SHARES) # else we'll be completely confused
            numsegs = e.get_param("num_segments")
            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
            segsize = e.get_param("segment_size")
            _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize,
                     NUM_SEGMENTS, segsize,
                     (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize)

            shareholders = {}
            for shnum in range(NUM_SHARES):
                peer = FakeBucketWriterProxy()
                shareholders[shnum] = peer
                all_shareholders.append(peer)
            e.set_shareholders(shareholders)
            return e.start()
        d.addCallback(_ready)

        def _check(res):
            (uri_extension_hash, required_shares, num_shares, file_size) = res
            self.failUnless(isinstance(uri_extension_hash, str))
            self.failUnlessEqual(len(uri_extension_hash), 32)
            for peer in all_shareholders:
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes.
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)

        return d

    # a series of 3*3 tests to check out edge conditions. One axis is how the
    # plaintext length relates to the segment size: one byte short of a
    # multiple of the segment size, an exact multiple, or one byte over. For
    # example, for 25-byte segments, we test 74 bytes, 75 bytes, and 76
    # bytes.

    # on the other axis is how many leaves in the block hash tree we wind up
    # with, relative to a power of 2, so 2^a+(-1,0,1). Each segment turns
    # into a single leaf. So we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long share
    # hash chain.

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which get 15 blockhashes.

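    # As a worked example of where those expected hash counts come from (a
    # sketch, not used by the tests themselves): the block hash tree is a
    # Merkle tree whose leaf count is num_segments rounded up to a power of
    # 2, and a binary tree with L leaves holds 2*L-1 hashes in all:
    #
    #     def expected_block_hash_count(num_segments):
    #         leaves = 1
    #         while leaves < num_segments:
    #             leaves *= 2      # round up to a power of 2
    #         return 2*leaves - 1  # 3 or 4 segments -> 7, 5 segments -> 15
    #
    # The 8-long share hash chain arises similarly: 100 shares round up to
    # 128 leaves, and each bucket needs one sibling hash per level of that
    # 7-level tree, plus its own leaf hash, to validate against the root.
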
    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # 5 segments: 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)

class Roundtrip(unittest.TestCase):
    def send_and_recover(self, k_and_happy_and_n=(25,75,100),
                         AVAILABLE_SHARES=None,
                         datalen=76,
                         max_segment_size=25,
                         bucket_modes={},
                         recover_mode="recover",
                         ):
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = k_and_happy_and_n[2]
        data = make_data(datalen)
        d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                      max_segment_size, bucket_modes, data)
        # that fires with (res, key, shareholders), where res includes the
        # uri_extension_hash
        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
        # that fires with (newdata, fd)
        def _downloaded((newdata, fd)):
            self.failUnless(newdata == data)
            return fd
        d.addCallback(_downloaded)
        return d

    def send(self, k_and_happy_and_n, AVAILABLE_SHARES, max_segment_size,
             bucket_modes, data):
        k, happy, n = k_and_happy_and_n
        NUM_SHARES = n
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = NUM_SHARES
        e = encode.Encoder()
        u = upload.Data(data, convergence="some convergence string")
        # force use of multiple segments by using a low max_segment_size
        u.max_segment_size = max_segment_size
        u.encoding_param_k = k
        u.encoding_param_happy = happy
        u.encoding_param_n = n
        eu = upload.EncryptAnUploadable(u)
        d = e.set_encrypted_uploadable(eu)

        shareholders = {}
        def _ready(res):
            k, happy, n = e.get_param("share_counts")
            assert n == NUM_SHARES # else we'll be completely confused
            for shnum in range(NUM_SHARES):
                mode = bucket_modes.get(shnum, "good")
                peer = FakeBucketWriterProxy(mode)
                shareholders[shnum] = peer
            e.set_shareholders(shareholders)
            return e.start()
        d.addCallback(_ready)
        def _sent(res):
            d1 = u.get_encryption_key()
            d1.addCallback(lambda key: (res, key, shareholders))
            return d1
        d.addCallback(_sent)
        return d

    def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
                recover_mode):
        (uri_extension_hash, required_shares, num_shares, file_size) = res

        if "corrupt_key" in recover_mode:
            # we corrupt the key, so that the decrypted data is corrupted and
            # will fail the plaintext hash check. Since we're manually
            # attaching shareholders, the fact that the storage index is also
            # corrupted doesn't matter.
            key = flip_bit(key)

        u = uri.CHKFileURI(key=key,
                           uri_extension_hash=uri_extension_hash,
                           needed_shares=required_shares,
                           total_shares=num_shares,
                           size=file_size)
        URI = u.to_string()

        client = FakeClient()
        target = download.Data()
        fd = download.FileDownloader(client, URI, target)

        # we manually cycle the FileDownloader through a number of steps that
        # would normally be sequenced by a Deferred chain in
        # FileDownloader.start(), to give us more control over the process.
        # In particular, by bypassing _get_all_shareholders, we skip
        # permuted-peerlist selection.
        for shnum, bucket in shareholders.items():
            if shnum < AVAILABLE_SHARES and bucket.closed:
                fd.add_share_bucket(shnum, bucket)
        fd._got_all_shareholders(None)

        # Make it possible to obtain uri_extension from the shareholders.
        # Arrange for shareholders[0] to be the first, so we can selectively
        # corrupt the data it returns.
        fd._uri_extension_sources = shareholders.values()
        fd._uri_extension_sources.remove(shareholders[0])
        fd._uri_extension_sources.insert(0, shareholders[0])

        d = defer.succeed(None)

        # have the FileDownloader retrieve a copy of uri_extension itself
        d.addCallback(fd._obtain_uri_extension)

        if "corrupt_crypttext_hashes" in recover_mode:
            # replace everybody's crypttext hash trees with a different one
            # (computed over a different file), then modify our uri_extension
            # to reflect the new crypttext hash tree root
            def _corrupt_crypttext_hashes(uri_extension):
                assert isinstance(uri_extension, dict)
                assert 'crypttext_root_hash' in uri_extension
                badhash = hashutil.tagged_hash("bogus", "data")
                bad_crypttext_hashes = [badhash] * uri_extension['num_segments']
                badtree = hashtree.HashTree(bad_crypttext_hashes)
                for bucket in shareholders.values():
                    bucket.crypttext_hashes = list(badtree)
                uri_extension['crypttext_root_hash'] = badtree[0]
                return uri_extension
            d.addCallback(_corrupt_crypttext_hashes)

        d.addCallback(fd._got_uri_extension)

        # also have the FileDownloader ask for hash trees
        d.addCallback(fd._get_hashtrees)

        d.addCallback(fd._create_validated_buckets)
        d.addCallback(fd._download_all_segments)
        d.addCallback(fd._done)
        def _done(newdata):
            return (newdata, fd)
        d.addCallback(_done)
        return d

    def test_not_enough_shares(self):
        d = self.send_and_recover((4,8,10), AVAILABLE_SHARES=2)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_one_share_per_peer(self):
        return self.send_and_recover()

    def test_74(self):
        return self.send_and_recover(datalen=74)
    def test_75(self):
        return self.send_and_recover(datalen=75)
    def test_51(self):
        return self.send_and_recover(datalen=51)

    def test_99(self):
        return self.send_and_recover(datalen=99)
    def test_100(self):
        return self.send_and_recover(datalen=100)
    def test_76(self):
        return self.send_and_recover(datalen=76)

    def test_124(self):
        return self.send_and_recover(datalen=124)
    def test_125(self):
        return self.send_and_recover(datalen=125)
    def test_101(self):
        return self.send_and_recover(datalen=101)

    # the following tests all use 4-out-of-10 encoding

    def test_bad_blocks(self):
        # the first 6 servers have bad blocks, which will be caught by the
        # blockhashes
        modemap = dict([(i, "bad block")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_bad_blocks_failure(self):
        # the first 7 servers have bad blocks, which will be caught by the
        # blockhashes, and the download will fail
        modemap = dict([(i, "bad block")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_bad_blockhashes(self):
        # the first 6 servers have bad block hashes, so the blockhash tree
        # will not validate
        modemap = dict([(i, "bad blockhash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_bad_blockhashes_failure(self):
        # the first 7 servers have bad block hashes, so the blockhash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad blockhash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_bad_sharehashes(self):
        # the first 6 servers have bad share hashes, so the share hash tree
        # will not validate
        modemap = dict([(i, "bad sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def assertFetchFailureIn(self, fd, where):
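        # fd._fetch_failures counts, per item, how many fetch attempts
        # failed and forced the downloader to fall back to another server.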
        expected = {"uri_extension": 0,
                    "plaintext_hashroot": 0,
                    "plaintext_hashtree": 0,
                    "crypttext_hashroot": 0,
                    "crypttext_hashtree": 0,
                    }
        if where is not None:
            expected[where] += 1
        self.failUnlessEqual(fd._fetch_failures, expected)

    def test_good(self):
        # just to make sure the test harness works when we aren't
        # intentionally causing failures
        modemap = dict([(i, "good") for i in range(10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, None)
        return d

    def test_bad_uri_extension(self):
        # the first server has a bad uri_extension block, so we will fail
        # over to a different server.
        modemap = dict([(i, "bad uri_extension") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        d.addCallback(self.assertFetchFailureIn, "uri_extension")
        return d

    def test_bad_sharehashes_failure(self):
        # the first 7 servers have bad share hashes, so the share hash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_missing_sharehashes(self):
        # the first 6 servers are missing their sharehashes, so the
        # sharehash tree will not validate
        modemap = dict([(i, "missing sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_missing_sharehashes_failure(self):
        # the first 7 servers are missing their sharehashes, so the
        # sharehash tree will not validate, and the download will fail
        modemap = dict([(i, "missing sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_lost_one_shareholder(self):
        # we have enough shareholders when we start, but one segment in we
        # lose one of them. The upload should still succeed, as long as we
        # still have 'shares_of_happiness' peers left.
        modemap = dict([(i, "good") for i in range(9)] +
                       [(i, "lost") for i in range(9, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_lost_one_shareholder_early(self):
        # we have enough shareholders when we choose peers, but just before
        # we send the 'start' message, we lose one of them. The upload should
        # still succeed, as long as we still have 'shares_of_happiness' peers
        # left.
        modemap = dict([(i, "good") for i in range(9)] +
                       [(i, "lost-early") for i in range(9, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_lost_many_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all but one of them. The upload should fail.
        modemap = dict([(i, "good") for i in range(1)] +
                       [(i, "lost") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError), res)
        d.addBoth(_done)
        return d

    def test_lost_all_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all of them. The upload should fail.
        modemap = dict([(i, "lost") for i in range(10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError))
        d.addBoth(_done)
        return d