# src/allmydata/test/test_encode.py
# (tahoe-lafs, as of "switch UploadResults to use get_uri(), hide internal ._uri")
from zope.interface import implements
from twisted.trial import unittest
from twisted.internet import defer
from twisted.python.failure import Failure
from foolscap.api import fireEventually
from allmydata import uri
from allmydata.immutable import encode, upload, checker
from allmydata.util import hashutil
from allmydata.util.assertutil import _assert
from allmydata.util.consumer import download_to_data
from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
from allmydata.test.no_network import GridTestMixin

class LostPeerError(Exception):
    pass

def flip_bit(good): # flips the last bit
    return good[:-1] + chr(ord(good[-1]) ^ 0x01)
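# e.g. flip_bit("\x00") == "\x01": the string comes back the same length but
# with its final bit inverted, which the fake proxies below use to hand out
# corrupted hashes and blocks.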

class FakeBucketReaderWriterProxy:
    implements(IStorageBucketWriter, IStorageBucketReader)
    # these are used for both reading and writing
    def __init__(self, mode="good", peerid="peer"):
        self.mode = mode
        self.blocks = {}
        self.plaintext_hashes = []
        self.crypttext_hashes = []
        self.block_hashes = None
        self.share_hashes = None
        self.closed = False
        self.peerid = peerid

    def get_peerid(self):
        return self.peerid

    def _start(self):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return fireEventually(f)
        return defer.succeed(self)

    def put_header(self):
        return self._start()

    def put_block(self, segmentnum, data):
        if self.mode == "lost-early":
            f = Failure(LostPeerError("I went away early"))
            return fireEventually(f)
        def _try():
            assert not self.closed
            assert segmentnum not in self.blocks
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_try)

    def put_crypttext_hashes(self, hashes):
        def _try():
            assert not self.closed
            assert not self.crypttext_hashes
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_try)

    def put_block_hashes(self, blockhashes):
        def _try():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_try)

    def put_share_hashes(self, sharehashes):
        def _try():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_try)

    def put_uri_extension(self, uri_extension):
        def _try():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_try)

    def close(self):
        def _try():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_try)

    def abort(self):
        return defer.succeed(None)

    def get_block_data(self, blocknum, blocksize, size):
        d = self._start()
        def _try(unused=None):
            assert isinstance(blocknum, (int, long))
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        d.addCallback(_try)
        return d

    def get_plaintext_hashes(self):
        d = self._start()
        def _try(unused=None):
            hashes = self.plaintext_hashes[:]
            return hashes
        d.addCallback(_try)
        return d

    def get_crypttext_hashes(self):
        d = self._start()
        def _try(unused=None):
            hashes = self.crypttext_hashes[:]
            if self.mode == "bad crypttext hashroot":
                hashes[0] = flip_bit(hashes[0])
            if self.mode == "bad crypttext hash":
                hashes[1] = flip_bit(hashes[1])
            return hashes
        d.addCallback(_try)
        return d

    def get_block_hashes(self, at_least_these=()):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad blockhash":
                hashes = self.block_hashes[:]
                hashes[1] = flip_bit(hashes[1])
                return hashes
            return self.block_hashes
        d.addCallback(_try)
        return d

    def get_share_hashes(self, at_least_these=()):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad sharehash":
                hashes = self.share_hashes[:]
                hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
                return hashes
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        d.addCallback(_try)
        return d

    def get_uri_extension(self):
        d = self._start()
        def _try(unused=None):
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        d.addCallback(_try)
        return d


def make_data(length):
    data = "happy happy joy joy" * 100
    assert length <= len(data)
    return data[:length]

class ValidatedExtendedURIProxy(unittest.TestCase):
    timeout = 240 # It takes longer than 120 seconds on Francois's arm box.
    K = 4
    M = 10
    SIZE = 200
    SEGSIZE = 72
    _TMP = SIZE%SEGSIZE
    if _TMP == 0:
        _TMP = SEGSIZE
    if _TMP % K != 0:
        _TMP += (K - (_TMP % K))
    TAIL_SEGSIZE = _TMP
    _TMP = SIZE / SEGSIZE
    if SIZE % SEGSIZE != 0:
        _TMP += 1
    NUM_SEGMENTS = _TMP
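    # For the constants above: 200 % 72 = 56, and 56 is already a multiple
    # of K=4, so TAIL_SEGSIZE = 56; 200 // 72 = 2 with a remainder, so
    # NUM_SEGMENTS = 3. (The tail segment size is rounded up to a multiple
    # of K, presumably so the tail can still be split evenly into K FEC
    # input blocks.)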
    mindict = { 'segment_size': SEGSIZE,
                'crypttext_root_hash': '0'*hashutil.CRYPTO_VAL_SIZE,
                'share_root_hash': '1'*hashutil.CRYPTO_VAL_SIZE }
    optional_consistent = { 'crypttext_hash': '2'*hashutil.CRYPTO_VAL_SIZE,
                            'codec_name': "crs",
                            'codec_params': "%d-%d-%d" % (SEGSIZE, K, M),
                            'tail_codec_params': "%d-%d-%d" % (TAIL_SEGSIZE, K, M),
                            'num_segments': NUM_SEGMENTS,
                            'size': SIZE,
                            'needed_shares': K,
                            'total_shares': M,
                            'plaintext_hash': "anything",
                            'plaintext_root_hash': "anything", }
    # optional_inconsistent = { 'crypttext_hash': ('2'*(hashutil.CRYPTO_VAL_SIZE-1), "", 77),
    optional_inconsistent = { 'crypttext_hash': (77,),
                              'codec_name': ("digital fountain", ""),
                              'codec_params': ("%d-%d-%d" % (SEGSIZE, K-1, M),
                                               "%d-%d-%d" % (SEGSIZE-1, K, M),
                                               "%d-%d-%d" % (SEGSIZE, K, M-1)),
                              'tail_codec_params': ("%d-%d-%d" % (TAIL_SEGSIZE, K-1, M),
                                                    "%d-%d-%d" % (TAIL_SEGSIZE-1, K, M),
                                                    "%d-%d-%d" % (TAIL_SEGSIZE, K, M-1)),
                              'num_segments': (NUM_SEGMENTS-1,),
                              'size': (SIZE-1,),
                              'needed_shares': (K-1,),
                              'total_shares': (M-1,), }

    def _test(self, uebdict):
        uebstring = uri.pack_extension(uebdict)
        uebhash = hashutil.uri_extension_hash(uebstring)
        fb = FakeBucketReaderWriterProxy()
        fb.put_uri_extension(uebstring)
        verifycap = uri.CHKFileVerifierURI(storage_index='x'*16, uri_extension_hash=uebhash, needed_shares=self.K, total_shares=self.M, size=self.SIZE)
        vup = checker.ValidatedExtendedURIProxy(fb, verifycap)
        return vup.start()

    def _test_accept(self, uebdict):
        return self._test(uebdict)

    def _should_fail(self, res, expected_failures):
        if isinstance(res, Failure):
            res.trap(*expected_failures)
        else:
            self.fail("was supposed to raise %s, not get '%s'" % (expected_failures, res))

    def _test_reject(self, uebdict):
        d = self._test(uebdict)
        d.addBoth(self._should_fail, (KeyError, checker.BadURIExtension))
        return d

    def test_accept_minimal(self):
        return self._test_accept(self.mindict)

    def test_reject_insufficient(self):
        dl = []
        for k in self.mindict.iterkeys():
            insuffdict = self.mindict.copy()
            del insuffdict[k]
            d = self._test_reject(insuffdict)
            dl.append(d)
        return defer.DeferredList(dl)

    def test_accept_optional(self):
        dl = []
        for k in self.optional_consistent.iterkeys():
            mydict = self.mindict.copy()
            mydict[k] = self.optional_consistent[k]
            d = self._test_accept(mydict)
            dl.append(d)
        return defer.DeferredList(dl)

    def test_reject_optional(self):
        dl = []
        for k in self.optional_inconsistent.iterkeys():
            for v in self.optional_inconsistent[k]:
                mydict = self.mindict.copy()
                mydict[k] = v
                d = self._test_reject(mydict)
                dl.append(d)
        return defer.DeferredList(dl)

class Encode(unittest.TestCase):
    timeout = 2400 # It takes longer than 240 seconds on Zandr's ARM box.

    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        data = make_data(datalen)
        # force use of multiple segments
        e = encode.Encoder()
        u = upload.Data(data, convergence="some convergence string")
        u.max_segment_size = max_segment_size
        u.encoding_param_k = 25
        u.encoding_param_happy = 75
        u.encoding_param_n = 100
        eu = upload.EncryptAnUploadable(u)
        d = e.set_encrypted_uploadable(eu)

        all_shareholders = []
        def _ready(res):
            k,happy,n = e.get_param("share_counts")
            _assert(n == NUM_SHARES) # else we'll be completely confused
            numsegs = e.get_param("num_segments")
            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
            segsize = e.get_param("segment_size")
            _assert( (NUM_SEGMENTS-1)*segsize < len(data) <= NUM_SEGMENTS*segsize,
                     NUM_SEGMENTS, segsize,
                     (NUM_SEGMENTS-1)*segsize, len(data), NUM_SEGMENTS*segsize)

            shareholders = {}
            servermap = {}
            for shnum in range(NUM_SHARES):
                peer = FakeBucketReaderWriterProxy()
                shareholders[shnum] = peer
                servermap.setdefault(shnum, set()).add(peer.get_peerid())
                all_shareholders.append(peer)
            e.set_shareholders(shareholders, servermap)
            return e.start()
        d.addCallback(_ready)

        def _check(res):
            verifycap = res
            self.failUnless(isinstance(verifycap.uri_extension_hash, str))
            self.failUnlessEqual(len(verifycap.uri_extension_hash), 32)
            for i,peer in enumerate(all_shareholders):
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # each peer gets a full tree of block hashes. For 3 or 4
                # segments, that's 7 hashes. For 5 segments it's 15 hashes.
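                # (round the segment count up to a power of two P; a full
                # binary Merkle tree over P leaves holds 2*P-1 hashes, so
                # 3-4 segments -> P=4 -> 7, and 5 segments -> P=8 -> 15)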
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # each peer also gets their necessary chain of share hashes.
                # For 100 shares (rounded up to 128 leaves), that's 8 hashes
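                # (presumably the share's own leaf hash plus the
                # log2(128) = 7 sibling hashes on the path to the root)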
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)

        return d

    def test_send_74(self):
        # 3 segments (25, 25, 24)
        return self.do_encode(25, 74, 100, 3, 7, 8)
    def test_send_75(self):
        # 3 segments (25, 25, 25)
        return self.do_encode(25, 75, 100, 3, 7, 8)
    def test_send_51(self):
        # 3 segments (25, 25, 1)
        return self.do_encode(25, 51, 100, 3, 7, 8)

    def test_send_76(self):
        # encode a 76 byte file (in 4 segments: 25,25,25,1) to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)
    def test_send_99(self):
        # 4 segments: 25,25,25,24
        return self.do_encode(25, 99, 100, 4, 7, 8)
    def test_send_100(self):
        # 4 segments: 25,25,25,25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    def test_send_124(self):
        # 5 segments: 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)
    def test_send_125(self):
        # 5 segments: 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)
    def test_send_101(self):
        # 5 segments: 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)

class Roundtrip(GridTestMixin, unittest.TestCase):

    # a series of 3*3 tests to check out edge conditions. One axis is how
    # the plaintext length relates to a whole number of segments:
    # n*segsize + (-1, 0, +1). For example, with 25-byte segments we test
    # 74 bytes, 75 bytes, and 76 bytes.

    # the other axis is how many leaves the block hash tree winds up with,
    # relative to a power of 2: 2^a+(-1,0,1). Each segment turns into a
    # single leaf, so we'd like to check out, e.g., 3 segments, 4
    # segments, and 5 segments.

    # that results in the following series of data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # all tests encode to 100 shares, which means the share hash tree will
    # have 128 leaves, which means that buckets will be given an 8-long
    # share hash chain

    # all 3-segment files will have a 4-leaf blockhashtree, and thus expect
    # to get 7 blockhashes. 4-segment files will also get 4-leaf block hash
    # trees and 7 blockhashes. 5-segment files will get 8-leaf block hash
    # trees, which yield 15 blockhashes.

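    # a quick way to re-derive those numbers (a sketch, not used by the
    # tests): round the leaf count up to the next power of two P; a full
    # binary Merkle tree over P leaves holds 2*P-1 hashes, and a single
    # leaf's validation chain carries one hash per tree level:
    #   3-4 segments -> P=4   -> 2*4-1 = 7 blockhashes
    #   5 segments   -> P=8   -> 2*8-1 = 15 blockhashes
    #   100 shares   -> P=128 -> an 8-long share hash chain
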
    def test_74(self): return self.do_test_size(74)
    def test_75(self): return self.do_test_size(75)
    def test_51(self): return self.do_test_size(51)
    def test_99(self): return self.do_test_size(99)
    def test_100(self): return self.do_test_size(100)
    def test_76(self): return self.do_test_size(76)
    def test_124(self): return self.do_test_size(124)
    def test_125(self): return self.do_test_size(125)
    def test_101(self): return self.do_test_size(101)

    def upload(self, data):
        u = upload.Data(data, None)
        u.max_segment_size = 25
        u.encoding_param_k = 25
        u.encoding_param_happy = 1
        u.encoding_param_n = 100
        d = self.c0.upload(u)
        d.addCallback(lambda ur: self.c0.create_node_from_uri(ur.get_uri()))
        # returns a FileNode
        return d

    def do_test_size(self, size):
        self.basedir = self.mktemp()
        self.set_up_grid()
        self.c0 = self.g.clients[0]
        DATA = "p"*size
        d = self.upload(DATA)
        d.addCallback(lambda n: download_to_data(n))
        def _downloaded(newdata):
            self.failUnlessEqual(newdata, DATA)
        d.addCallback(_downloaded)
        return d