import itertools, struct, re
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import mutable, uri, dirnode, download
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.encode import NotEnoughPeersError
from allmydata.interfaces import IURI, INewDirectoryURI, \
     IMutableFileURI, IUploadable, IFileURI
from allmydata.filenode import LiteralFileNode
from foolscap.eventual import eventually
from foolscap.logging import log
import sha

#from allmydata.test.common import FakeMutableFileNode
#FakeFilenode = FakeMutableFileNode

class FakeFilenode(mutable.MutableFileNode):
    counter = itertools.count(1)
    all_contents = {}
    all_rw_friends = {}

    def create(self, initial_contents):
        d = mutable.MutableFileNode.create(self, initial_contents)
        def _then(res):
            self.all_contents[self.get_uri()] = initial_contents
            return res
        d.addCallback(_then)
        return d
    def init_from_uri(self, myuri):
        mutable.MutableFileNode.init_from_uri(self, myuri)
        return self
    def _generate_pubprivkeys(self, key_size):
        count = self.counter.next()
        return FakePubKey(count), FakePrivKey(count)
    def _publish(self, initial_contents):
        self.all_contents[self.get_uri()] = initial_contents
        return defer.succeed(self)

    def download_to_data(self):
        if self.is_readonly():
            assert self.all_rw_friends.has_key(self.get_uri()), (self.get_uri(), id(self.all_rw_friends))
            return defer.succeed(self.all_contents[self.all_rw_friends[self.get_uri()]])
        else:
            return defer.succeed(self.all_contents[self.get_uri()])
    def update(self, newdata):
        self.all_contents[self.get_uri()] = newdata
        return defer.succeed(None)
    def overwrite(self, newdata):
        return self.update(newdata)

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in mutable.Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}

    def read(self, peerid, storage_index):
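        # with no response order configured, answer synchronously with this
        # peer's share dict; otherwise queue the answer for _fire_readers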
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return shares
        d = defer.Deferred()
        if not self._pending:
            reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
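        # release the queued answers: peers named in _sequence first, in
        # that order, then any stragglers in arbitrary order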
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
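        # model a write by splicing 'data' into the existing share string
        # at 'offset' (cStringIO pads with NULs if the share is too short)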
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakePublish(mutable.Publish):

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
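        # serve reads directly from FakeStorage: 'ss' is the fake (peerid,)
        # connection tuple built by FakeClient.get_permuted_peers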
        assert ss[0] == peerid
        assert shnums == []
        return defer.maybeDeferred(self._storage.read, peerid, storage_index)

    def _do_testreadwrite(self, peerid, secrets,
                          tw_vectors, read_vector):
        storage_index = self._node._uri.storage_index
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self._storage.write(peerid, storage_index, shnum, offset, data)
        answer = (True, readv)
        return defer.succeed(answer)


class FakeNewDirectoryNode(dirnode.NewDirectoryNode):
    filenode_class = FakeFilenode

class FakeClient:
    def __init__(self, num_peers=10):
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_empty_dirnode(self):
        n = FakeNewDirectoryNode(self)
        d = n.create()
        d.addCallback(lambda res: n)
        return d

    def create_dirnode_from_uri(self, u):
        return FakeNewDirectoryNode(self).init_from_uri(u)

    def create_mutable_file(self, contents=""):
        n = FakeFilenode(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        if INewDirectoryURI.providedBy(u):
            return self.create_dirnode_from_uri(u)
        if IFileURI.providedBy(u):
            if isinstance(u, uri.LiteralFileURI):
                return LiteralFileNode(u, self)
            else:
                # CHK
                raise RuntimeError("not simulated")
        assert IMutableFileURI.providedBy(u), u
        res = FakeFilenode(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        # TODO: include_myself=True
        """
        @return: list of (peerid, connection,)
        """
        peers_and_connections = [(pid, (pid,)) for pid in self._peerids]
        results = []
        for peerid, connection in peers_and_connections:
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [(r[1], r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FakeFilenode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d

class FakePubKey:
    def __init__(self, count):
        self.count = count
    def serialize(self):
        return "PUBKEY-%d" % self.count
    def verify(self, msg, signature):
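        # accept exactly the strings produced by FakePrivKey.sign:
        # "SIGN(<msg>)"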
        if signature[:5] != "SIGN(":
            return False
        if signature[5:-1] != msg:
            return False
        if signature[-1] != ")":
            return False
        return True

class FakePrivKey:
    def __init__(self, count):
        self.count = count
    def serialize(self):
        return "PRIVKEY-%d" % self.count
    def sign(self, data):
        return "SIGN(%s)" % data


class Filenode(unittest.TestCase):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = n.overwrite("contents 1")
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.update("contents 3"))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_to_data()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_to_data())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d


class Publish(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FakeFilenode(c)
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = mutable.Publish(fn)
        target_info = None
        d = defer.maybeDeferred(p._encrypt_and_encode, target_info,
                                CONTENTS, "READKEY", "IV"*8, 3, 10)
        def _done( ((shares, share_ids),
                    required_shares, total_shares,
                    segsize, data_length, target_info2) ):
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
            self.failUnlessEqual(required_shares, 3)
            self.failUnlessEqual(total_shares, 10)
            self.failUnlessEqual(segsize, 21)
            self.failUnlessEqual(data_length, len(CONTENTS))
            self.failUnlessIdentical(target_info, target_info2)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FakeFilenode(c)
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = mutable.Publish(fn)
        # make some fake shares
        shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
        target_info = None
        p._privkey = FakePrivKey(0)
        p._encprivkey = "encprivkey"
        p._pubkey = FakePubKey(0)
        d = defer.maybeDeferred(p._generate_shares,
                                (shares_and_ids,
                                 3, 10,
                                 21, # segsize
                                 len(CONTENTS),
                                 target_info),
                                3, # seqnum
                                "IV"*8)
        def _done( (seqnum, root_hash, final_shares, target_info2) ):
            self.failUnlessEqual(seqnum, 3)
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 381)
                # feed the share through the unpacker as a sanity-check
                pieces = mutable.unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, FakePubKey(0).serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnlessEqual(signature,
                                     FakePrivKey(0).sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # log2(10), rounded up
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "IV"*8)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, "encprivkey")
            self.failUnlessIdentical(target_info, target_info2)
        d.addCallback(_done)
        return d

    def setup_for_sharemap(self, num_peers):
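        # wire a FakePublish to a FakeStorage, hand-setting the intermediate
        # state (_storage_index, _new_seqnum, _read_size) so that
        # _query_peers can be exercised on its own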
        c = FakeClient(num_peers)
        fn = FakeFilenode(c)
        s = FakeStorage()
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        CONTENTS = "some initial contents"
        fn.create(CONTENTS)
        p = FakePublish(fn)
        p._storage_index = "\x00"*32
        p._new_seqnum = 3
        p._read_size = 1000
        #r = mutable.Retrieve(fn)
        p._storage = s
        return c, p

    def shouldFail(self, expected_failure, which, call, *args, **kwargs):
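        # invoke 'call' and assert that it fails with expected_failure,
        # optionally checking that 'substring' appears in the failure text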
        substring = kwargs.pop("substring", None)
        d = defer.maybeDeferred(call, *args, **kwargs)
        def _done(res):
            if isinstance(res, failure.Failure):
                res.trap(expected_failure)
                if substring:
                    self.failUnless(substring in str(res),
                                    "substring '%s' not in '%s'"
                                    % (substring, str(res)))
            else:
                self.fail("%s was supposed to raise %s, not get '%s'" %
                          (which, expected_failure, res))
        d.addBoth(_done)
        return d

    def test_sharemap_20newpeers(self):
        c, p = self.setup_for_sharemap(20)

        total_shares = 10
        d = p._query_peers(total_shares)
        def _done(target_info):
            # ignore the shares_per_peer element and recount from target_map
            (target_map, ignored) = target_info
            shares_per_peer = {}
            for shnum in target_map:
                for (peerid, old_seqnum, old_R) in target_map[shnum]:
                    #print "shnum[%d]: send to %s [oldseqnum=%s]" % \
                    #      (shnum, idlib.b2a(peerid), old_seqnum)
                    if peerid not in shares_per_peer:
                        shares_per_peer[peerid] = 1
                    else:
                        shares_per_peer[peerid] += 1
            # verify that we're sending only one share per peer
            for peerid, count in shares_per_peer.items():
                self.failUnlessEqual(count, 1)
        d.addCallback(_done)
        return d

    def test_sharemap_3newpeers(self):
        c, p = self.setup_for_sharemap(3)

        total_shares = 10
        d = p._query_peers(total_shares)
        def _done(target_info):
            # as above, recount shares per peer from target_map
            (target_map, ignored) = target_info
            shares_per_peer = {}
            for shnum in target_map:
                for (peerid, old_seqnum, old_R) in target_map[shnum]:
                    if peerid not in shares_per_peer:
                        shares_per_peer[peerid] = 1
                    else:
                        shares_per_peer[peerid] += 1
            # verify that we're sending 3 or 4 shares per peer
            for peerid, count in shares_per_peer.items():
                self.failUnless(count in (3,4), count)
        d.addCallback(_done)
        return d

    def test_sharemap_nopeers(self):
        c, p = self.setup_for_sharemap(0)

        total_shares = 10
        d = self.shouldFail(NotEnoughPeersError, "test_sharemap_nopeers",
                            p._query_peers, total_shares)
        return d

    def test_write(self):
        total_shares = 10
        c, p = self.setup_for_sharemap(20)
        p._privkey = FakePrivKey(0)
        p._encprivkey = "encprivkey"
        p._pubkey = FakePubKey(0)
        # make some fake shares
        CONTENTS = "some initial contents"
        shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
        d = defer.maybeDeferred(p._query_peers, total_shares)
        IV = "IV"*8
        d.addCallback(lambda target_info:
                      p._generate_shares( (shares_and_ids,
                                           3, total_shares,
                                           21, # segsize
                                           len(CONTENTS),
                                           target_info),
                                          3, # seqnum
                                          IV))
        d.addCallback(p._send_shares, IV)
        def _done((surprised, dispatch_map)):
            self.failIf(surprised, "surprised!")
        d.addCallback(_done)
        return d

class FakeRetrieve(mutable.Retrieve):
    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        d = defer.maybeDeferred(self._storage.read, peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def _deserialize_pubkey(self, pubkey_s):
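        # recover a FakePubKey from the "PUBKEY-<n>" string produced by
        # FakePubKey.serialize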
        mo = re.search(r"^PUBKEY-(\d+)$", pubkey_s)
        if not mo:
            raise RuntimeError("mangled pubkey")
        count = mo.group(1)
        return FakePubKey(int(count))


class Roundtrip(unittest.TestCase):

    def setup_for_publish(self, num_peers):
        c = FakeClient(num_peers)
        fn = FakeFilenode(c)
        s = FakeStorage()
        # .create usually returns a Deferred, but we happen to know it's
        # synchronous
        fn.create("")
        p = FakePublish(fn)
        p._storage = s
        r = FakeRetrieve(fn)
        r._storage = s
        return c, s, fn, p, r

    def test_basic(self):
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d

    def flip_bit(self, original, byte_offset):
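        # return a copy of 'original' with the low bit of the byte at
        # byte_offset inverted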
        return (original[:byte_offset] +
                chr(ord(original[byte_offset]) ^ 0x01) +
                original[byte_offset+1:])


    def shouldFail(self, expected_failure, which, substring,
                   callable, *args, **kwargs):
        assert substring is None or isinstance(substring, str)
        d = defer.maybeDeferred(callable, *args, **kwargs)
        def done(res):
            if isinstance(res, failure.Failure):
                res.trap(expected_failure)
                if substring:
                    self.failUnless(substring in str(res),
                                    "substring '%s' not in '%s'"
                                    % (substring, str(res)))
            else:
                self.fail("%s was supposed to raise %s, not get '%s'" %
                          (which, expected_failure, res))
        d.addBoth(done)
        return d

    def _corrupt_all(self, offset, substring, refetch_pubkey=False,
                     should_succeed=False):
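        # publish a file, flip one bit at 'offset' in every share on every
        # peer, then check that retrieval fails with 'substring' (or, with
        # should_succeed=True, still returns the contents). 'offset' may be
        # a field name from the share layout, a raw byte offset, or a
        # (name, delta) tuple.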
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            if refetch_pubkey:
                # clear the pubkey, to force a fetch
                r._pubkey = None
            for peerid in s._peers:
                shares = s._peers[peerid]
                for shnum in shares:
                    data = shares[shnum]
                    (version,
                     seqnum,
                     root_hash,
                     IV,
                     k, N, segsize, datalen,
                     o) = mutable.unpack_header(data)
                    if isinstance(offset, tuple):
                        offset1, offset2 = offset
                    else:
                        offset1 = offset
                        offset2 = 0
                    if offset1 == "pubkey":
                        real_offset = 107
                    elif offset1 in o:
                        real_offset = o[offset1]
                    else:
                        real_offset = offset1
                    real_offset = int(real_offset) + offset2
                    assert isinstance(real_offset, int), offset
                    shares[shnum] = self.flip_bit(data, real_offset)
        d.addCallback(_published)
        if should_succeed:
            d.addCallback(lambda res: r.retrieve())
        else:
            d.addCallback(lambda res:
                          self.shouldFail(NotEnoughPeersError,
                                          "_corrupt_all(offset=%s)" % (offset,),
                                          substring,
                                          r.retrieve))
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        return self._corrupt_all(0, "AssertionError")

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint
        return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint",
                                 refetch_pubkey=True)

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._corrupt_all("signature", "signature is invalid",
                                 refetch_pubkey=True)

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._corrupt_all("block_hash_tree", "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader
        return self._corrupt_all("enc_privkey", None, should_succeed=True)

    def test_short_read(self):
        c, s, fn, p, r = self.setup_for_publish(20)
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            # force a short read, to make Retrieve._got_results re-send the
            # queries. But don't make it so short that we can't read the
            # header.
            r._read_size = mutable.HEADER_LENGTH + 10
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d

    def test_basic_sequenced(self):
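        # same as test_basic, but FakeStorage answers the read queries one
        # peer at a time, in peerid order, instead of all at once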
        c, s, fn, p, r = self.setup_for_publish(20)
        s._sequence = c._peerids[:]
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.
        c, s, fn, p, r = self.setup_for_publish(20)
        s._sequence = c._peerids[:]
        contents = "New contents go here"
        d = p.publish(contents)
        def _published(res):
            r._pubkey = None
            homes = [peerid for peerid in c._peerids
                     if s._peers.get(peerid, {})]
            k = fn.get_required_shares()
            homes_to_corrupt = homes[:-k]
            for peerid in homes_to_corrupt:
                shares = s._peers[peerid]
                for shnum in shares:
                    data = shares[shnum]
                    (version,
                     seqnum,
                     root_hash,
                     IV,
                     k, N, segsize, datalen,
                     o) = mutable.unpack_header(data)
                    offset = 107 # pubkey
                    shares[shnum] = self.flip_bit(data, offset)
            return r.retrieve()
        d.addCallback(_published)
        def _retrieved(new_contents):
            self.failUnlessEqual(contents, new_contents)
        d.addCallback(_retrieved)
        return d

    def _encode(self, c, s, fn, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FakeFilenode(c)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        fn2._current_seqnum = 0
        fn2._current_roothash = "\x00" * 32
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        p2 = FakePublish(fn2)
        p2._storage = s
        p2._storage._peers = {} # clear existing storage
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.
        c, s, fn, p, r = self.setup_for_publish(20)
        # we ignore fn, p, and r

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FakeFilenode(c)
        fn3.init_from_uri(fn.get_uri())

        # now we upload a file through fn, and grab its shares
        d = self._encode(c, s, fn, 3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(c, s, fn, 4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(c, s, fn, 4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but choosing which set each share comes
            # from according to the following sequence:
            #
            #  4-of-9   a  s2
            #  4-of-9   b  s2
            #  4-of-7   c   s3
            #  4-of-9   d  s2
            #  3-of-10  e s1
            #  3-of-10  f s1
            #  3-of-10  g s1
            #  4-of-9   h  s2
            #
            # so that no version can be recovered until fetch [f], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will be recoverable at fetch [g].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]

            sharemap = {}

            for peerid in c._peerids:
                peerid_s = shortnodeid_b2a(peerid)
                # reset this peer's share dict once (not once per shnum), so
                # a peer holding several shares keeps all of them
                s._peers[peerid] = peers = {}
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    in_1 = shnum in self._shares1[peerid]
                    in_2 = shnum in self._shares2.get(peerid, {})
                    in_3 = shnum in self._shares3.get(peerid, {})
                    #print peerid_s, shnum, which, in_1, in_2, in_3
                    if which == 1:
                        if in_1:
                            peers[shnum] = self._shares1[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 2:
                        if in_2:
                            peers[shnum] = self._shares2[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 3:
                        if in_3:
                            peers[shnum] = self._shares3[peerid][shnum]
                            sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            s._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        def _retrieve(res):
            r3 = FakeRetrieve(fn3)
            r3._storage = s
            return r3.retrieve()
        d.addCallback(_retrieve)
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d