import itertools, struct, re
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer
from twisted.python import failure, log
from allmydata import mutable, uri, dirnode, download
from allmydata.util.hashutil import tagged_hash
from allmydata.encode import NotEnoughPeersError
from allmydata.interfaces import IURI, INewDirectoryURI, \
     IMutableFileURI, IUploadable, IFileURI
from allmydata.filenode import LiteralFileNode
import sha

#from allmydata.test.common import FakeMutableFileNode
#FakeFilenode = FakeMutableFileNode

class FakeFilenode(mutable.MutableFileNode):
    counter = itertools.count(1)
    all_contents = {}
    all_rw_friends = {}
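    # note that these are class-level attributes, so every FakeFilenode in
    # a test shares one in-memory "grid": all_contents maps a node's URI to
    # its current contents, and all_rw_friends maps a read-only URI back to
    # its read-write counterpart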

    def create(self, initial_contents):
        d = mutable.MutableFileNode.create(self, initial_contents)
        def _then(res):
            self.all_contents[self.get_uri()] = initial_contents
            return res
        d.addCallback(_then)
        return d
    def init_from_uri(self, myuri):
        mutable.MutableFileNode.init_from_uri(self, myuri)
        return self
    def _generate_pubprivkeys(self):
        count = self.counter.next()
        return FakePubKey(count), FakePrivKey(count)
    def _publish(self, initial_contents):
        self.all_contents[self.get_uri()] = initial_contents
        return defer.succeed(self)

    def download_to_data(self):
        if self.is_readonly():
            assert self.get_uri() in self.all_rw_friends, \
                   (self.get_uri(), id(self.all_rw_friends))
            return defer.succeed(self.all_contents[self.all_rw_friends[self.get_uri()]])
        else:
            return defer.succeed(self.all_contents[self.get_uri()])
    def replace(self, newdata):
        self.all_contents[self.get_uri()] = newdata
        return defer.succeed(None)

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in mutable.Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.
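    #
    # Example (any peerid/storage_index strings work, and the storage
    # index is ignored): two writes to the same share number overlay at
    # the given offsets, like a remote write vector would:
    #   s = FakeStorage()
    #   s.write("peer", "si", 0, 0, "abcde")
    #   s.write("peer", "si", 0, 2, "XY")
    #   s.read("peer", "si")   # => {0: "abXYe"}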

    def __init__(self):
        self._peers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        return shares

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
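        # overlay the new data at `offset` on top of whatever this share
        # already holds, mimicking a remote write vector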
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

class FakePublish(mutable.Publish):

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
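        # `ss` is the fake "connection" produced by
        # FakeClient.get_permuted_peers, which is just a (peerid,) tuple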
        assert ss[0] == peerid
        assert shnums == []
        return defer.succeed(self._storage.read(peerid, storage_index))

    def _do_testreadwrite(self, peerid, secrets,
                          tw_vectors, read_vector):
        storage_index = self._node._uri.storage_index
        # always pass: parrot the test vectors back to the caller, so every
        # test vector compares equal and the write is accepted
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self._storage.write(peerid, storage_index, shnum, offset, data)
        answer = (True, readv)
        return defer.succeed(answer)

class FakeNewDirectoryNode(dirnode.NewDirectoryNode):
    filenode_class = FakeFilenode

class FakeClient:
    def __init__(self, num_peers=10):
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_empty_dirnode(self):
        n = FakeNewDirectoryNode(self)
        d = n.create()
        d.addCallback(lambda res: n)
        return d

    def create_dirnode_from_uri(self, u):
        return FakeNewDirectoryNode(self).init_from_uri(u)

    def create_mutable_file(self, contents=""):
        n = FakeFilenode(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def create_node_from_uri(self, u):
        u = IURI(u)
        if INewDirectoryURI.providedBy(u):
            return self.create_dirnode_from_uri(u)
        if IFileURI.providedBy(u):
            if isinstance(u, uri.LiteralFileURI):
                return LiteralFileNode(u, self)
            else:
                # CHK
                raise RuntimeError("not simulated")
        assert IMutableFileURI.providedBy(u), u
        res = FakeFilenode(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        # TODO: include_myself=True
        """
        @return: list of (peerid, connection,)
        """
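        # rank peers by sha1(key + peerid), mimicking the permuted ring
        # that the real peer-selection code walks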
        peers_and_connections = [(pid, (pid,)) for pid in self._peerids]
        results = []
        for peerid, connection in peers_and_connections:
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [(r[1], r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FakeFilenode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d

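# The fake key classes implement just enough of the signing interface
# for these tests: FakePrivKey.sign() wraps a message as "SIGN(<msg>)",
# and FakePubKey.verify() checks for exactly that wrapping.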
class FakePubKey:
    def __init__(self, count):
        self.count = count
    def serialize(self):
        return "PUBKEY-%d" % self.count
    def verify(self, msg, signature):
        if signature[:5] != "SIGN(":
            return False
        if signature[5:-1] != msg:
            return False
        if signature[-1] != ")":
            return False
        return True

class FakePrivKey:
    def __init__(self, count):
        self.count = count
    def serialize(self):
        return "PRIVKEY-%d" % self.count
    def sign(self, data):
        return "SIGN(%s)" % data

class Filenode(unittest.TestCase):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = n.replace("contents 1")
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.replace("contents 2"))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_to_data()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.replace("contents 2"))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

class Publish(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FakeFilenode(c)
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = mutable.Publish(fn)
        target_info = None
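        # encrypt and 3-of-10 erasure-code the contents; "READKEY" and
        # "IV"*8 (16 bytes) are stand-ins for the real key material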
        d = defer.maybeDeferred(p._encrypt_and_encode, target_info,
                                CONTENTS, "READKEY", "IV"*8, 3, 10)
        def _done( ((shares, share_ids),
                    required_shares, total_shares,
                    segsize, data_length, target_info2) ):
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
            self.failUnlessEqual(required_shares, 3)
            self.failUnlessEqual(total_shares, 10)
            self.failUnlessEqual(segsize, 21)
            self.failUnlessEqual(data_length, len(CONTENTS))
            self.failUnlessIdentical(target_info, target_info2)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FakeFilenode(c)
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = mutable.Publish(fn)
        r = mutable.Retrieve(fn)
        # make some fake shares
        shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
        target_info = None
        p._privkey = FakePrivKey(0)
        p._encprivkey = "encprivkey"
        p._pubkey = FakePubKey(0)
        d = defer.maybeDeferred(p._generate_shares,
                                (shares_and_ids,
                                 3, 10,
                                 21, # segsize
                                 len(CONTENTS),
                                 target_info),
                                3, # seqnum
                                "IV"*8)
        def _done( (seqnum, root_hash, final_shares, target_info2) ):
            self.failUnlessEqual(seqnum, 3)
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 381)
                # feed the share through the unpacker as a sanity-check
                pieces = mutable.unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, FakePubKey(0).serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnlessEqual(signature,
                                     FakePrivKey(0).sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10)) hashes
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "IV"*8)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, "encprivkey")
            self.failUnlessIdentical(target_info, target_info2)
        d.addCallback(_done)
        return d

    def setup_for_sharemap(self, num_peers):
        c = FakeClient(num_peers)
        fn = FakeFilenode(c)
        s = FakeStorage()
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = FakePublish(fn)
        p._storage_index = "\x00"*32
        p._new_seqnum = 3
        p._read_size = 1000
        #r = mutable.Retrieve(fn)
        p._storage = s
        return c, p

    def shouldFail(self, expected_failure, which, call, *args, **kwargs):
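        # assert that `call` errbacks with expected_failure (and, if given,
        # that `substring` appears in the failure's string form); any other
        # outcome flunks the test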
        substring = kwargs.pop("substring", None)
        d = defer.maybeDeferred(call, *args, **kwargs)
        def _done(res):
            if isinstance(res, failure.Failure):
                res.trap(expected_failure)
                if substring:
                    self.failUnless(substring in str(res),
                                    "substring '%s' not in '%s'"
                                    % (substring, str(res)))
            else:
                self.fail("%s was supposed to raise %s, not get '%s'" %
                          (which, expected_failure, res))
        d.addBoth(_done)
        return d

    def test_sharemap_20newpeers(self):
        c, p = self.setup_for_sharemap(20)

        total_shares = 10
        d = p._query_peers(total_shares)
        def _done(target_info):
            (target_map, shares_per_peer) = target_info
            shares_per_peer = {}
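            # recount shares-per-peer from target_map itself, rather than
            # trusting the summary returned by _query_peers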
            for shnum in target_map:
                for (peerid, old_seqnum, old_R) in target_map[shnum]:
                    #print "shnum[%d]: send to %s [oldseqnum=%s]" % \
                    #      (shnum, idlib.b2a(peerid), old_seqnum)
                    if peerid not in shares_per_peer:
                        shares_per_peer[peerid] = 1
                    else:
                        shares_per_peer[peerid] += 1
            # verify that we're sending only one share per peer
            for peerid, count in shares_per_peer.items():
                self.failUnlessEqual(count, 1)
        d.addCallback(_done)
        return d

    def test_sharemap_3newpeers(self):
        c, p = self.setup_for_sharemap(3)

        total_shares = 10
        d = p._query_peers(total_shares)
        def _done(target_info):
            (target_map, shares_per_peer) = target_info
            shares_per_peer = {}
            for shnum in target_map:
                for (peerid, old_seqnum, old_R) in target_map[shnum]:
                    if peerid not in shares_per_peer:
                        shares_per_peer[peerid] = 1
                    else:
                        shares_per_peer[peerid] += 1
            # verify that we're sending 3 or 4 shares per peer
            for peerid, count in shares_per_peer.items():
                self.failUnless(count in (3,4), count)
        d.addCallback(_done)
        return d

    def test_sharemap_nopeers(self):
        c, p = self.setup_for_sharemap(0)

        total_shares = 10
        d = self.shouldFail(NotEnoughPeersError, "test_sharemap_nopeers",
                            p._query_peers, total_shares)
        return d

    def test_write(self):
        total_shares = 10
        c, p = self.setup_for_sharemap(20)
        p._privkey = FakePrivKey(0)
        p._encprivkey = "encprivkey"
        p._pubkey = FakePubKey(0)
        # make some fake shares
        CONTENTS = "some initial contents"
        shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
        d = defer.maybeDeferred(p._query_peers, total_shares)
        IV = "IV"*8
        d.addCallback(lambda target_info:
                      p._generate_shares( (shares_and_ids,
                                           3, total_shares,
                                           21, # segsize
                                           len(CONTENTS),
                                           target_info),
                                          3, # seqnum
                                          IV))
        d.addCallback(p._send_shares, IV)
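        # `surprised` would indicate that a test vector failed, i.e. some
        # other writer got there first; the always-pass fake storage should
        # never report that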
        def _done((surprised, dispatch_map)):
            self.failIf(surprised, "surprised!")
        d.addCallback(_done)
        return d

class FakeRetrieve(mutable.Retrieve):
    def _do_read(self, ss, peerid, storage_index, shnums, readv):
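        # serve each requested read vector out of the fake storage; an
        # empty `shnums` means "all shares held by this peer"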
        shares = self._storage.read(peerid, storage_index)

        response = {}
        for shnum in shares:
            if shnums and shnum not in shnums:
                continue
            vector = response[shnum] = []
            for (offset, length) in readv:
                vector.append(shares[shnum][offset:offset+length])
        return defer.succeed(response)

    def _deserialize_pubkey(self, pubkey_s):
        mo = re.search(r"^PUBKEY-(\d+)$", pubkey_s)
        if not mo:
            raise RuntimeError("mangled pubkey")
        count = mo.group(1)
        return FakePubKey(int(count))

class Roundtrip(unittest.TestCase):

    def setup_for_publish(self, num_peers):
        c = FakeClient(num_peers)
        fn = FakeFilenode(c)
        s = FakeStorage()
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        fn.create("")
        p = FakePublish(fn)
        p._storage = s
        r = FakeRetrieve(fn)
        r._storage = s
        return c, s, fn, p, r

    def test_basic(self):
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d

    def flip_bit(self, original, byte_offset):
        return (original[:byte_offset] +
                chr(ord(original[byte_offset]) ^ 0x01) +
                original[byte_offset+1:])

    def shouldFail(self, expected_failure, which, substring,
                    callable, *args, **kwargs):
        assert substring is None or isinstance(substring, str)
        d = defer.maybeDeferred(callable, *args, **kwargs)
        def done(res):
            if isinstance(res, failure.Failure):
                res.trap(expected_failure)
                if substring:
                    self.failUnless(substring in str(res),
                                    "substring '%s' not in '%s'"
                                    % (substring, str(res)))
            else:
                self.fail("%s was supposed to raise %s, not get '%s'" %
                          (which, expected_failure, res))
        d.addBoth(done)
        return d

    def _corrupt_all(self, offset, substring, refetch_pubkey=False,
                     should_succeed=False):
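        # publish a file, flip one bit at `offset` in every share of every
        # peer, then check that retrieve either fails with `substring` in
        # the error or (if should_succeed) still works. `offset` may be a
        # raw byte offset, a field name from the share layout's offset
        # table, or a (field, delta) tuple.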
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            if refetch_pubkey:
                # clear the pubkey, to force a fetch
                r._pubkey = None
            for peerid in s._peers:
                shares = s._peers[peerid]
                for shnum in shares:
                    data = shares[shnum]
                    (version,
                     seqnum,
                     root_hash,
                     IV,
                     k, N, segsize, datalen,
                     o) = mutable.unpack_header(data)
                    if isinstance(offset, tuple):
                        offset1, offset2 = offset
                    else:
                        offset1 = offset
                        offset2 = 0
                    if offset1 == "pubkey":
                        real_offset = 107
                    elif offset1 in o:
                        real_offset = o[offset1]
                    else:
                        real_offset = offset1
                    real_offset = int(real_offset) + offset2
                    assert isinstance(real_offset, int), offset
                    shares[shnum] = self.flip_bit(data, real_offset)
        d.addCallback(_published)
        if should_succeed:
            d.addCallback(lambda res: r.retrieve())
        else:
            d.addCallback(lambda res:
                          self.shouldFail(NotEnoughPeersError,
                                          "_corrupt_all(offset=%s)" % (offset,),
                                          substring,
                                          r.retrieve))
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        return self._corrupt_all(0, "AssertionError")

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint
        return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint",
                                 refetch_pubkey=True)

    def test_corrupt_all_sig(self):
        # a corrupt signature will fail verification.
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._corrupt_all("signature", "signature is invalid",
                                 refetch_pubkey=True)

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._corrupt_all("block_hash_tree", "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader
        return self._corrupt_all("enc_privkey", None, should_succeed=True)

    def test_short_read(self):
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            # force a short read, to make Retrieve._got_results re-send the
            # queries. But don't make it so short that we can't read the
            # header.
            r._read_size = mutable.HEADER_LENGTH + 10
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d