# test_encode.py -- tahoe-lafs, src/allmydata/test/test_encode.py
# (captured from gitweb blob view; commit: "test_encode.py: refactor
# send_and_recover a bit")
1 #! /usr/bin/env python
2
3 from twisted.trial import unittest
4 from twisted.internet import defer
5 from twisted.python.failure import Failure
6 from foolscap import eventual
7 from allmydata import encode, download
8 from allmydata.util import bencode
9 from allmydata.uri import pack_uri
10 from cStringIO import StringIO
11
class FakePeer:
    """Stand-in for a remote peer: owns a FakeStorageServer and exposes
    the same callRemote()/get_service() surface the real peer object has.
    """
    def __init__(self, mode="good"):
        self.ss = FakeStorageServer(mode)

    def callRemote(self, methname, *args, **kwargs):
        # Run the named local method asynchronously, so callers see the
        # same Deferred-based interface as a real remote call.
        return defer.maybeDeferred(
            lambda: getattr(self, methname)(*args, **kwargs))

    def get_service(self, sname):
        # the tests only ever ask for the storage server
        assert sname == "storageserver"
        return self.ss
25
class FakeStorageServer:
    """Fake remote storage server. callRemote() turns into a local method
    call that fires on a later turn of the event loop, mimicking a network
    round-trip. 'mode' controls how allocate_buckets responds.
    """
    def __init__(self, mode):
        self.mode = mode

    def callRemote(self, methname, *args, **kwargs):
        # fireEventually() delays the dispatch to a later reactor turn
        d = eventual.fireEventually()
        d.addCallback(
            lambda ign: getattr(self, methname)(*args, **kwargs))
        return d

    def allocate_buckets(self, verifierid, sharenums, shareize, blocksize, canary):
        # "full": server has no room, accepts nothing.
        # "already got them": server claims to already hold every share.
        # otherwise: hand out one FakeBucketWriter per requested share.
        if self.mode == "full":
            return (set(), {})
        if self.mode == "already got them":
            return (set(sharenums), {})
        writers = dict([(num, FakeBucketWriter()) for num in sharenums])
        return (set(), writers)
43
class LostPeerError(Exception):
    # raised by FakeBucketWriter in "lost" mode to simulate a peer that
    # disconnects partway through an upload
    pass
46
class FakeBucketWriter:
    # these are used for both reading and writing.
    # 'mode' selects failure injection: "good" (default), "lost" (raise
    # LostPeerError after the first block), "bad block" / "bad
    # plaintexthash" / "bad crypttexthash" / "bad blockhash" /
    # "bad sharehash" (return bit-flipped data), "missing sharehash"
    # (pretend not to know our own share hash chain).
    def __init__(self, mode="good"):
        self.mode = mode
        self.blocks = {}
        self.plaintext_hashes = None
        self.crypttext_hashes = None
        self.block_hashes = None
        self.share_hashes = None
        # FIX: previously unset until put_thingA() ran, so reading
        # self.thingA before a put raised AttributeError; initialize it
        # like every other attribute.
        self.thingA = None
        self.closed = False

    def callRemote(self, methname, *args, **kwargs):
        # dispatch to the local method of the same name, asynchronously
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        return defer.maybeDeferred(_call)

    def put_block(self, segmentnum, data):
        assert not self.closed
        assert segmentnum not in self.blocks
        if self.mode == "lost" and segmentnum >= 1:
            # simulate a peer that vanishes after accepting one segment
            raise LostPeerError("I'm going away now")
        self.blocks[segmentnum] = data

    def put_plaintext_hashes(self, hashes):
        assert not self.closed
        assert self.plaintext_hashes is None
        self.plaintext_hashes = hashes

    def put_crypttext_hashes(self, hashes):
        assert not self.closed
        assert self.crypttext_hashes is None
        self.crypttext_hashes = hashes

    def put_block_hashes(self, blockhashes):
        assert not self.closed
        assert self.block_hashes is None
        self.block_hashes = blockhashes

    def put_share_hashes(self, sharehashes):
        assert not self.closed
        assert self.share_hashes is None
        self.share_hashes = sharehashes

    def put_thingA(self, thingA):
        assert not self.closed
        self.thingA = thingA

    def close(self):
        assert not self.closed
        self.closed = True

    def flip_bit(self, good):
        # corrupt the last byte of 'good' by flipping its low bit
        return good[:-1] + chr(ord(good[-1]) ^ 0x01)

    def get_block(self, blocknum):
        assert isinstance(blocknum, (int, long))
        if self.mode == "bad block":
            return self.flip_bit(self.blocks[blocknum])
        return self.blocks[blocknum]

    def get_plaintext_hashes(self):
        if self.mode == "bad plaintexthash":
            hashes = self.plaintext_hashes[:]
            hashes[1] = self.flip_bit(hashes[1])
            return hashes
        return self.plaintext_hashes

    def get_crypttext_hashes(self):
        if self.mode == "bad crypttexthash":
            hashes = self.crypttext_hashes[:]
            hashes[1] = self.flip_bit(hashes[1])
            return hashes
        return self.crypttext_hashes

    def get_block_hashes(self):
        if self.mode == "bad blockhash":
            hashes = self.block_hashes[:]
            hashes[1] = self.flip_bit(hashes[1])
            return hashes
        return self.block_hashes

    def get_share_hashes(self):
        if self.mode == "bad sharehash":
            hashes = self.share_hashes[:]
            hashes[1] = (hashes[1][0], self.flip_bit(hashes[1][1]))
            return hashes
        if self.mode == "missing sharehash":
            # one sneaky attack would be to pretend we don't know our own
            # sharehash, which could manage to frame someone else.
            # download.py is supposed to guard against this case.
            return []
        return self.share_hashes
138
139
def make_data(length):
    """Return 'length' bytes of deterministic test data."""
    fixture = "happy happy joy joy" * 100
    assert length <= len(fixture)
    return fixture[:length]
144
class Encode(unittest.TestCase):
    """Exercise encode.Encoder against FakeBucketWriters, checking that
    each peer receives the expected number of blocks, block hashes, and
    share-hash-chain entries."""

    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        """Encode 'datalen' bytes into NUM_SHARES shares and verify the
        per-peer block counts and hash-tree sizes. Returns a Deferred."""
        data = make_data(datalen)
        # force use of multiple segments
        options = {"max_segment_size": max_segment_size}
        e = encode.Encoder(options)
        nonkey = "\x00" * 16
        e.setup(StringIO(data), nonkey)
        assert e.num_shares == NUM_SHARES # else we'll be completely confused
        e.setup_codec() # need to rebuild the codec for that change
        assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
        shareholders = {}
        all_shareholders = []
        for shnum in range(NUM_SHARES):
            peer = FakeBucketWriter()
            shareholders[shnum] = peer
            all_shareholders.append(peer)
        e.set_shareholders(shareholders)
        d = e.start()
        def _check(roothash):
            self.failUnless(isinstance(roothash, str))
            self.failUnlessEqual(len(roothash), 32)
            for i,peer in enumerate(all_shareholders):
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)

        return d

    # a series of 3*3 tests to check out edge conditions. One axis is how the
    # plaintext is divided into segments: kn+(-1,0,1). Another way to express
    # that is that n%k == -1 or 0 or 1. For example, for 25-byte segments, we
    # might test 74 bytes, 75 bytes, and 76 bytes.

    # on the other axis is how many leaves in the block hash tree we wind up
    # with, relative to a power of 2, so 2^a+(-1,0,1). Each segment turns
    # into a single leaf. So we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long share
    # hash chain

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which get 15 blockhashes.

    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    # NOTE: an earlier, broken duplicate of test_send_101 (which called a
    # nonexistent self.make_data() and passed data where do_encode expects
    # a length) was shadowed by the definition below; it has been removed.

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # 5 segments: 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)
246
class Roundtrip(unittest.TestCase):
    """Upload-then-download round trips through encode.Encoder and
    download.FileDownloader, with per-bucket failure modes injected via
    FakeBucketWriter."""

    def send_and_recover(self, k_and_happy_and_n=(25,75,100),
                         AVAILABLE_SHARES=None,
                         datalen=76,
                         max_segment_size=25,
                         bucket_modes=None,
                         ):
        """Encode 'datalen' bytes, then recover them with the downloader.

        k_and_happy_and_n: (needed, happy, total) share counts.
        AVAILABLE_SHARES: how many buckets the downloader may use
                          (defaults to all of them).
        bucket_modes: optional dict mapping sharenum -> FakeBucketWriter
                      failure mode; unlisted shares default to "good".
        """
        # avoid a mutable default argument; behavior is unchanged
        if bucket_modes is None:
            bucket_modes = {}
        NUM_SHARES = k_and_happy_and_n[2]
        if AVAILABLE_SHARES is None:
            AVAILABLE_SHARES = NUM_SHARES
        data = make_data(datalen)
        # force use of multiple segments
        options = {"max_segment_size": max_segment_size,
                   "needed_and_happy_and_total_shares": k_and_happy_and_n}
        e = encode.Encoder(options)
        nonkey = "\x00" * 16
        e.setup(StringIO(data), nonkey)

        assert e.num_shares == NUM_SHARES # else we'll be completely confused
        e.setup_codec() # need to rebuild the codec for that change

        shareholders = {}
        for shnum in range(NUM_SHARES):
            mode = bucket_modes.get(shnum, "good")
            shareholders[shnum] = FakeBucketWriter(mode)
        e.set_shareholders(shareholders)
        e.set_thingA_data({'verifierid': "V" * 20,
                           'fileid': "F" * 20,
                           })
        d = e.start()
        d.addCallback(self.recover, nonkey, e, shareholders, AVAILABLE_SHARES)
        def _downloaded(newdata):
            self.failUnless(newdata == data)
        d.addCallback(_downloaded)
        return d

    def recover(self, thingA_hash, nonkey, e, shareholders, AVAILABLE_SHARES):
        """Drive a FileDownloader by hand over the buckets that were just
        written, returning a Deferred that fires with the recovered data."""
        URI = pack_uri(storage_index="S" * 20,
                       key=nonkey,
                       thingA_hash=thingA_hash,
                       needed_shares=e.required_shares,
                       total_shares=e.num_shares,
                       size=e.file_size)
        client = None
        target = download.Data()
        fd = download.FileDownloader(client, URI, target)
        # verifierid/fileid checks are disabled because this test feeds the
        # downloader placeholder ids ("V"*20 / "F"*20)
        fd.check_verifierid = False
        fd.check_fileid = False
        # grab a copy of thingA from one of the shareholders
        thingA = shareholders[0].thingA
        thingA_data = bencode.bdecode(thingA)
        # (a dead, never-read "NOTthingA" dict that rebuilt these fields by
        # hand from the encoder was removed here; the decoded thingA above
        # is what actually feeds the downloader)
        fd._got_thingA(thingA_data)
        # hand over only the first AVAILABLE_SHARES buckets that finished
        for shnum, bucket in shareholders.items():
            if shnum < AVAILABLE_SHARES and bucket.closed:
                fd.add_share_bucket(shnum, bucket)
        fd._got_all_shareholders(None)
        fd._create_validated_buckets(None)
        d = fd._download_all_segments(None)
        d.addCallback(fd._done)
        return d

    def test_not_enough_shares(self):
        d = self.send_and_recover((4,8,10), AVAILABLE_SHARES=2)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_one_share_per_peer(self):
        return self.send_and_recover()

    def test_74(self):
        return self.send_and_recover(datalen=74)
    def test_75(self):
        return self.send_and_recover(datalen=75)
    def test_51(self):
        return self.send_and_recover(datalen=51)

    def test_99(self):
        return self.send_and_recover(datalen=99)
    def test_100(self):
        return self.send_and_recover(datalen=100)
    def test_76(self):
        return self.send_and_recover(datalen=76)

    def test_124(self):
        return self.send_and_recover(datalen=124)
    def test_125(self):
        return self.send_and_recover(datalen=125)
    def test_101(self):
        return self.send_and_recover(datalen=101)

    # the following tests all use 4-out-of-10 encoding

    def test_bad_blocks(self):
        # the first 6 servers have bad blocks, which will be caught by the
        # blockhashes
        modemap = dict([(i, "bad block")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_bad_blocks_failure(self):
        # the first 7 servers have bad blocks, which will be caught by the
        # blockhashes, and the download will fail
        modemap = dict([(i, "bad block")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_bad_blockhashes(self):
        # the first 6 servers have bad block hashes, so the blockhash tree
        # will not validate
        modemap = dict([(i, "bad blockhash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_bad_blockhashes_failure(self):
        # the first 7 servers have bad block hashes, so the blockhash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad blockhash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_bad_sharehashes(self):
        # the first 6 servers have bad share hashes, so the sharehash tree
        # will not validate
        modemap = dict([(i, "bad sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_bad_sharehashes_failure(self):
        # the first 7 servers have bad share hashes, so the sharehash tree
        # will not validate, and the download will fail
        modemap = dict([(i, "bad sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_missing_sharehashes(self):
        # the first 6 servers are missing their sharehashes, so the
        # sharehash tree will not validate
        modemap = dict([(i, "missing sharehash")
                        for i in range(6)]
                       + [(i, "good")
                          for i in range(6, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_missing_sharehashes_failure(self):
        # the first 7 servers are missing their sharehashes, so the
        # sharehash tree will not validate, and the download will fail
        modemap = dict([(i, "missing sharehash")
                        for i in range(7)]
                       + [(i, "good")
                          for i in range(7, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(download.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_lost_one_shareholder(self):
        # we have enough shareholders when we start, but one segment in we
        # lose one of them. The upload should still succeed, as long as we
        # still have 'shares_of_happiness' peers left.
        modemap = dict([(i, "good") for i in range(9)] +
                       [(i, "lost") for i in range(9, 10)])
        return self.send_and_recover((4,8,10), bucket_modes=modemap)

    def test_lost_many_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all but one of them. The upload should fail.
        modemap = dict([(i, "good") for i in range(1)] +
                       [(i, "lost") for i in range(1, 10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError))
        d.addBoth(_done)
        return d

    def test_lost_all_shareholders(self):
        # we have enough shareholders when we start, but one segment in we
        # lose all of them. The upload should fail.
        modemap = dict([(i, "lost") for i in range(10)])
        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
        def _done(res):
            self.failUnless(isinstance(res, Failure))
            self.failUnless(res.check(encode.NotEnoughPeersError))
        d.addBoth(_done)
        return d
476