test_mutable.py: add more tests of post-mapupdate corruption, to support #474 testing
1
2 import os, struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from twisted.python import failure
7 from allmydata import uri, download, storage
8 from allmydata.util import base32, testutil, idlib
9 from allmydata.util.idlib import shortnodeid_b2a
10 from allmydata.util.hashutil import tagged_hash
11 from allmydata.util.fileutil import make_dirs
12 from allmydata.encode import NotEnoughSharesError
13 from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
14      FileTooLargeError
15 from foolscap.eventual import eventually, fireEventually
16 from foolscap.logging import log
17 import sha
18
19 from allmydata.mutable.node import MutableFileNode, BackoffAgent
20 from allmydata.mutable.common import DictOfSets, ResponseCache, \
21      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
22      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
23      NotEnoughServersError
24 from allmydata.mutable.retrieve import Retrieve
25 from allmydata.mutable.publish import Publish
26 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
27 from allmydata.mutable.layout import unpack_header, unpack_share
28
29 # this "FastMutableFileNode" exists solely to speed up tests by using smaller
30 # public/private keys. Once we switch to fast DSA-based keys, we can get rid
31 # of this.
32
33 class FastMutableFileNode(MutableFileNode):
34     SIGNATURE_KEY_SIZE = 522
35
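# IntentionalError is assumed to be a plain Exception subclass; it is the
# error that FakeStorage.read() raises for the "fail" special answer below.
# This minimal definition is an assumption made for self-containedness: if
# the module defines its own IntentionalError elsewhere, that definition is
# the one in effect at runtime.
class IntentionalError(Exception):
    pass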
36 # this "FakeStorage" exists to put the share data in RAM and avoid using real
37 # network connections, both to speed up the tests and to reduce the amount of
38 # non-mutable.py code being exercised.
39
40 class FakeStorage:
41     # this class replaces the collection of storage servers, allowing the
42     # tests to examine and manipulate the published shares. It also lets us
43     # control the order in which read queries are answered, to exercise more
44     # of the error-handling code in Retrieve.
45     #
46     # Note that we ignore the storage index: this FakeStorage instance can
47     # only be used for a single storage index.
48
49
50     def __init__(self):
51         self._peers = {}
52         # _sequence is used to cause the responses to occur in a specific
53         # order. If it is in use, then we will defer queries instead of
54         # answering them right away, accumulating the Deferreds in a dict. We
55         # don't know exactly how many queries we'll get, so exactly one
56         # second after the first query arrives, we will release them all (in
57         # order).
58         self._sequence = None
59         self._pending = {}
60         self._pending_timer = None
61         self._special_answers = {}
62
63     def read(self, peerid, storage_index):
64         shares = self._peers.get(peerid, {})
65         if self._special_answers.get(peerid, []):
66             mode = self._special_answers[peerid].pop(0)
67             if mode == "fail":
68                 shares = failure.Failure(IntentionalError())
69             elif mode == "none":
70                 shares = {}
71             elif mode == "normal":
72                 pass
73         if self._sequence is None:
74             return defer.succeed(shares)
75         d = defer.Deferred()
76         if not self._pending:
77             self._pending_timer = reactor.callLater(1.0, self._fire_readers)
78         self._pending[peerid] = (d, shares)
79         return d
80
81     def _fire_readers(self):
82         self._pending_timer = None
83         pending = self._pending
84         self._pending = {}
85         extra = []
86         for peerid in self._sequence:
87             if peerid in pending:
88                 d, shares = pending.pop(peerid)
89                 eventually(d.callback, shares)
90         for (d, shares) in pending.values():
91             eventually(d.callback, shares)
92
93     def write(self, peerid, storage_index, shnum, offset, data):
94         if peerid not in self._peers:
95             self._peers[peerid] = {}
96         shares = self._peers[peerid]
97         f = StringIO()
98         f.write(shares.get(shnum, ""))
99         f.seek(offset)
100         f.write(data)
101         shares[shnum] = f.getvalue()
102
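# A sketch of how a test can drive FakeStorage's knobs (the attribute names
# come from the classes in this file; the particular values are illustrative
# only):
#
#   c = FakeClient()
#   s = c._storage
#   # hold all read queries, then answer them in this peer order one second
#   # after the first query arrives:
#   s._sequence = list(c._peerids)
#   # make the next read from the first peer return a Failure instead of shares:
#   s._special_answers[c._peerids[0]] = ["fail"]
#
# Peers not named in _sequence are answered after the listed ones.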
103
104 class FakeStorageServer:
105     def __init__(self, peerid, storage):
106         self.peerid = peerid
107         self.storage = storage
108         self.queries = 0
109     def callRemote(self, methname, *args, **kwargs):
110         def _call():
111             meth = getattr(self, methname)
112             return meth(*args, **kwargs)
113         d = fireEventually()
114         d.addCallback(lambda res: _call())
115         return d
116
117     def slot_readv(self, storage_index, shnums, readv):
118         d = self.storage.read(self.peerid, storage_index)
119         def _read(shares):
120             response = {}
121             for shnum in shares:
122                 if shnums and shnum not in shnums:
123                     continue
124                 vector = response[shnum] = []
125                 for (offset, length) in readv:
126                     assert isinstance(offset, (int, long)), offset
127                     assert isinstance(length, (int, long)), length
128                     vector.append(shares[shnum][offset:offset+length])
129             return response
130         d.addCallback(_read)
131         return d
132
133     def slot_testv_and_readv_and_writev(self, storage_index, secrets,
134                                         tw_vectors, read_vector):
135         # always-pass: parrot the test vectors back to them.
136         readv = {}
137         for shnum, (testv, writev, new_length) in tw_vectors.items():
138             for (offset, length, op, specimen) in testv:
139                 assert op in ("le", "eq", "ge")
140             # TODO: this isn't right, the read is controlled by read_vector,
141             # not by testv
142             readv[shnum] = [ specimen
143                              for (offset, length, op, specimen)
144                              in testv ]
145             for (offset, data) in writev:
146                 self.storage.write(self.peerid, storage_index, shnum,
147                                    offset, data)
148         answer = (True, readv)
149         return fireEventually(answer)
150
151
152 # our "FakeClient" provides just enough of the real Client's functionality to
153 # let the tests run.
154
155 class FakeClient:
156     mutable_file_node_class = FastMutableFileNode
157
158     def __init__(self, num_peers=10):
159         self._storage = FakeStorage()
160         self._num_peers = num_peers
161         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
162                          for i in range(self._num_peers)]
163         self._connections = dict([(peerid, FakeStorageServer(peerid,
164                                                              self._storage))
165                                   for peerid in self._peerids])
166         self.nodeid = "fakenodeid"
167
168     def log(self, msg, **kw):
169         return log.msg(msg, **kw)
170
171     def get_renewal_secret(self):
172         return "I hereby permit you to renew my files"
173     def get_cancel_secret(self):
174         return "I hereby permit you to cancel my leases"
175
176     def create_mutable_file(self, contents=""):
177         n = self.mutable_file_node_class(self)
178         d = n.create(contents)
179         d.addCallback(lambda res: n)
180         return d
181
182     def notify_retrieve(self, r):
183         pass
184     def notify_publish(self, p, size):
185         pass
186     def notify_mapupdate(self, u):
187         pass
188
189     def create_node_from_uri(self, u):
190         u = IURI(u)
191         assert IMutableFileURI.providedBy(u), u
192         res = self.mutable_file_node_class(self).init_from_uri(u)
193         return res
194
195     def get_permuted_peers(self, service_name, key):
196         """
197         @return: list of (peerid, connection,)
198         """
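        # Rank peers by hashing the key together with each peerid, giving a
        # stable pseudo-random permutation per storage index (a simplified
        # stand-in for the real client's permuted-peer selection).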
199         results = []
200         for (peerid, connection) in self._connections.items():
201             assert isinstance(peerid, str)
202             permuted = sha.new(key + peerid).digest()
203             results.append((permuted, peerid, connection))
204         results.sort()
205         results = [ (r[1],r[2]) for r in results]
206         return results
207
208     def upload(self, uploadable):
209         assert IUploadable.providedBy(uploadable)
210         d = uploadable.get_size()
211         d.addCallback(lambda length: uploadable.read(length))
212         #d.addCallback(self.create_mutable_file)
213         def _got_data(datav):
214             data = "".join(datav)
215             #newnode = FastMutableFileNode(self)
216             return uri.LiteralFileURI(data)
217         d.addCallback(_got_data)
218         return d
219
220
221 def flip_bit(original, byte_offset):
222     return (original[:byte_offset] +
223             chr(ord(original[byte_offset]) ^ 0x01) +
224             original[byte_offset+1:])
225
226 def corrupt(res, s, offset, shnums_to_corrupt=None):
227     # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
228     # list of shnums to corrupt.
229     for peerid in s._peers:
230         shares = s._peers[peerid]
231         for shnum in shares:
232             if (shnums_to_corrupt is not None
233                 and shnum not in shnums_to_corrupt):
234                 continue
235             data = shares[shnum]
236             (version,
237              seqnum,
238              root_hash,
239              IV,
240              k, N, segsize, datalen,
241              o) = unpack_header(data)
242             if isinstance(offset, tuple):
243                 offset1, offset2 = offset
244             else:
245                 offset1 = offset
246                 offset2 = 0
247             if offset1 == "pubkey":
248                 real_offset = 107
249             elif offset1 in o:
250                 real_offset = o[offset1]
251             else:
252                 real_offset = offset1
253             real_offset = int(real_offset) + offset2
254             assert isinstance(real_offset, int), offset
255             shares[shnum] = flip_bit(data, real_offset)
256     return res
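# Typical invocations of corrupt(), matching the offset handling above (the
# 'storage' argument stands for a FakeStorage instance):
#   corrupt(None, storage, 1)                        # integer: flip a bit at byte 1 (inside the seqnum)
#   corrupt(None, storage, "pubkey")                 # named field with a fixed offset (107)
#   corrupt(None, storage, ("share_hash_chain", 4))  # (field, extra): 4 bytes past the field's start
#   corrupt(None, storage, "share_data", range(5))   # restrict the damage to shnums 0..4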
257
258 class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
259     def setUp(self):
260         self.client = FakeClient()
261
262     def test_create(self):
263         d = self.client.create_mutable_file()
264         def _created(n):
265             self.failUnless(isinstance(n, FastMutableFileNode))
266             peer0 = self.client._peerids[0]
267             shnums = self.client._storage._peers[peer0].keys()
268             self.failUnlessEqual(len(shnums), 1)
269         d.addCallback(_created)
270         return d
271
272     def test_serialize(self):
273         n = MutableFileNode(self.client)
274         calls = []
275         def _callback(*args, **kwargs):
276             self.failUnlessEqual(args, (4,) )
277             self.failUnlessEqual(kwargs, {"foo": 5})
278             calls.append(1)
279             return 6
280         d = n._do_serialized(_callback, 4, foo=5)
281         def _check_callback(res):
282             self.failUnlessEqual(res, 6)
283             self.failUnlessEqual(calls, [1])
284         d.addCallback(_check_callback)
285
286         def _errback():
287             raise ValueError("heya")
288         d.addCallback(lambda res:
289                       self.shouldFail(ValueError, "_check_errback", "heya",
290                                       n._do_serialized, _errback))
291         return d
292
293     def test_upload_and_download(self):
294         d = self.client.create_mutable_file()
295         def _created(n):
296             d = defer.succeed(None)
297             d.addCallback(lambda res: n.get_servermap(MODE_READ))
298             d.addCallback(lambda smap: smap.dump(StringIO()))
299             d.addCallback(lambda sio:
300                           self.failUnless("3-of-10" in sio.getvalue()))
301             d.addCallback(lambda res: n.overwrite("contents 1"))
302             d.addCallback(lambda res: self.failUnlessIdentical(res, None))
303             d.addCallback(lambda res: n.download_best_version())
304             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
305             d.addCallback(lambda res: n.overwrite("contents 2"))
306             d.addCallback(lambda res: n.download_best_version())
307             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
308             d.addCallback(lambda res: n.download(download.Data()))
309             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
310             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
311             d.addCallback(lambda smap: n.upload("contents 3", smap))
312             d.addCallback(lambda res: n.download_best_version())
313             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
314             d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
315             d.addCallback(lambda smap:
316                           n.download_version(smap,
317                                              smap.best_recoverable_version()))
318             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
319             # test a file that is large enough to overcome the
320             # mapupdate-to-retrieve data caching (i.e. make the shares larger
321             # than the default readsize, which is 2000 bytes). A 15kB file
322             # will have 5kB shares.
323             d.addCallback(lambda res: n.overwrite("large size file" * 1000))
324             d.addCallback(lambda res: n.download_best_version())
325             d.addCallback(lambda res:
326                           self.failUnlessEqual(res, "large size file" * 1000))
327             return d
328         d.addCallback(_created)
329         return d
330
331     def test_create_with_initial_contents(self):
332         d = self.client.create_mutable_file("contents 1")
333         def _created(n):
334             d = n.download_best_version()
335             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
336             d.addCallback(lambda res: n.overwrite("contents 2"))
337             d.addCallback(lambda res: n.download_best_version())
338             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
339             return d
340         d.addCallback(_created)
341         return d
342
343     def test_create_with_too_large_contents(self):
344         BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
345         d = self.shouldFail(FileTooLargeError, "too_large",
346                             "SDMF is limited to one segment, and %d > %d" %
347                             (len(BIG), Publish.MAX_SEGMENT_SIZE),
348                             self.client.create_mutable_file, BIG)
349         d.addCallback(lambda res: self.client.create_mutable_file("small"))
350         def _created(n):
351             return self.shouldFail(FileTooLargeError, "too_large_2",
352                                    "SDMF is limited to one segment, and %d > %d" %
353                                    (len(BIG), Publish.MAX_SEGMENT_SIZE),
354                                    n.overwrite, BIG)
355         d.addCallback(_created)
356         return d
357
358     def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
359         d = n.get_servermap(MODE_READ)
360         d.addCallback(lambda servermap: servermap.best_recoverable_version())
361         d.addCallback(lambda verinfo:
362                       self.failUnlessEqual(verinfo[0], expected_seqnum))
363         return d
364
365     def test_modify(self):
366         def _modifier(old_contents):
367             return old_contents + "line2"
368         def _non_modifier(old_contents):
369             return old_contents
370         def _none_modifier(old_contents):
371             return None
372         def _error_modifier(old_contents):
373             raise ValueError("oops")
374         def _toobig_modifier(old_contents):
375             return "b" * (Publish.MAX_SEGMENT_SIZE+1)
376         calls = []
377         def _ucw_error_modifier(old_contents):
378             # simulate an UncoordinatedWriteError once
379             calls.append(1)
380             if len(calls) <= 1:
381                 raise UncoordinatedWriteError("simulated")
382             return old_contents + "line3"
383
384         d = self.client.create_mutable_file("line1")
385         def _created(n):
386             d = n.modify(_modifier)
387             d.addCallback(lambda res: n.download_best_version())
388             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
389             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
390
391             d.addCallback(lambda res: n.modify(_non_modifier))
392             d.addCallback(lambda res: n.download_best_version())
393             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
394             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
395
396             d.addCallback(lambda res: n.modify(_none_modifier))
397             d.addCallback(lambda res: n.download_best_version())
398             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
399             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
400
401             d.addCallback(lambda res:
402                           self.shouldFail(ValueError, "error_modifier", None,
403                                           n.modify, _error_modifier))
404             d.addCallback(lambda res: n.download_best_version())
405             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
406             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
407
408             d.addCallback(lambda res:
409                           self.shouldFail(FileTooLargeError, "toobig_modifier",
410                                           "SDMF is limited to one segment",
411                                           n.modify, _toobig_modifier))
412             d.addCallback(lambda res: n.download_best_version())
413             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
414             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
415
416             d.addCallback(lambda res: n.modify(_ucw_error_modifier))
417             d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
418             d.addCallback(lambda res: n.download_best_version())
419             d.addCallback(lambda res: self.failUnlessEqual(res,
420                                                            "line1line2line3"))
421             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
422
423             return d
424         d.addCallback(_created)
425         return d
426
427     def test_modify_backoffer(self):
428         def _modifier(old_contents):
429             return old_contents + "line2"
430         calls = []
431         def _ucw_error_modifier(old_contents):
432             # simulate an UncoordinatedWriteError once
433             calls.append(1)
434             if len(calls) <= 1:
435                 raise UncoordinatedWriteError("simulated")
436             return old_contents + "line3"
437         def _always_ucw_error_modifier(old_contents):
438             raise UncoordinatedWriteError("simulated")
439         def _backoff_stopper(node, f):
440             return f
441         def _backoff_pauser(node, f):
442             d = defer.Deferred()
443             reactor.callLater(0.5, d.callback, None)
444             return d
445
446         # the give-up-er will hit its maximum retry count quickly
447         giveuper = BackoffAgent()
448         giveuper._delay = 0.1
449         giveuper.factor = 1
450
451         d = self.client.create_mutable_file("line1")
452         def _created(n):
453             d = n.modify(_modifier)
454             d.addCallback(lambda res: n.download_best_version())
455             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
456             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
457
458             d.addCallback(lambda res:
459                           self.shouldFail(UncoordinatedWriteError,
460                                           "_backoff_stopper", None,
461                                           n.modify, _ucw_error_modifier,
462                                           _backoff_stopper))
463             d.addCallback(lambda res: n.download_best_version())
464             d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
465             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
466
467             def _reset_ucw_error_modifier(res):
468                 calls[:] = []
469                 return res
470             d.addCallback(_reset_ucw_error_modifier)
471             d.addCallback(lambda res: n.modify(_ucw_error_modifier,
472                                                _backoff_pauser))
473             d.addCallback(lambda res: n.download_best_version())
474             d.addCallback(lambda res: self.failUnlessEqual(res,
475                                                            "line1line2line3"))
476             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
477
478             d.addCallback(lambda res:
479                           self.shouldFail(UncoordinatedWriteError,
480                                           "giveuper", None,
481                                           n.modify, _always_ucw_error_modifier,
482                                           giveuper.delay))
483             d.addCallback(lambda res: n.download_best_version())
484             d.addCallback(lambda res: self.failUnlessEqual(res,
485                                                            "line1line2line3"))
486             d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
487
488             return d
489         d.addCallback(_created)
490         return d
491
492     def test_upload_and_download_full_size_keys(self):
493         self.client.mutable_file_node_class = MutableFileNode
494         d = self.client.create_mutable_file()
495         def _created(n):
496             d = defer.succeed(None)
497             d.addCallback(lambda res: n.get_servermap(MODE_READ))
498             d.addCallback(lambda smap: smap.dump(StringIO()))
499             d.addCallback(lambda sio:
500                           self.failUnless("3-of-10" in sio.getvalue()))
501             d.addCallback(lambda res: n.overwrite("contents 1"))
502             d.addCallback(lambda res: self.failUnlessIdentical(res, None))
503             d.addCallback(lambda res: n.download_best_version())
504             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
505             d.addCallback(lambda res: n.overwrite("contents 2"))
506             d.addCallback(lambda res: n.download_best_version())
507             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
508             d.addCallback(lambda res: n.download(download.Data()))
509             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
510             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
511             d.addCallback(lambda smap: n.upload("contents 3", smap))
512             d.addCallback(lambda res: n.download_best_version())
513             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
514             d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
515             d.addCallback(lambda smap:
516                           n.download_version(smap,
517                                              smap.best_recoverable_version()))
518             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
519             return d
520         d.addCallback(_created)
521         return d
522
523
524 class MakeShares(unittest.TestCase):
525     def test_encrypt(self):
526         c = FakeClient()
527         fn = FastMutableFileNode(c)
528         CONTENTS = "some initial contents"
529         d = fn.create(CONTENTS)
530         def _created(res):
531             p = Publish(fn, None)
532             p.salt = "SALT" * 4
533             p.readkey = "\x00" * 16
534             p.newdata = CONTENTS
535             p.required_shares = 3
536             p.total_shares = 10
537             p.setup_encoding_parameters()
538             return p._encrypt_and_encode()
539         d.addCallback(_created)
540         def _done(shares_and_shareids):
541             (shares, share_ids) = shares_and_shareids
542             self.failUnlessEqual(len(shares), 10)
543             for sh in shares:
544                 self.failUnless(isinstance(sh, str))
545                 self.failUnlessEqual(len(sh), 7)
546             self.failUnlessEqual(len(share_ids), 10)
547         d.addCallback(_done)
548         return d
549
550     def test_generate(self):
551         c = FakeClient()
552         fn = FastMutableFileNode(c)
553         CONTENTS = "some initial contents"
554         d = fn.create(CONTENTS)
555         def _created(res):
556             p = Publish(fn, None)
557             self._p = p
558             p.newdata = CONTENTS
559             p.required_shares = 3
560             p.total_shares = 10
561             p.setup_encoding_parameters()
562             p._new_seqnum = 3
563             p.salt = "SALT" * 4
564             # make some fake shares
565             shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
566             p._privkey = fn.get_privkey()
567             p._encprivkey = fn.get_encprivkey()
568             p._pubkey = fn.get_pubkey()
569             return p._generate_shares(shares_and_ids)
570         d.addCallback(_created)
571         def _generated(res):
572             p = self._p
573             final_shares = p.shares
574             root_hash = p.root_hash
575             self.failUnlessEqual(len(root_hash), 32)
576             self.failUnless(isinstance(final_shares, dict))
577             self.failUnlessEqual(len(final_shares), 10)
578             self.failUnlessEqual(sorted(final_shares.keys()), range(10))
579             for i,sh in final_shares.items():
580                 self.failUnless(isinstance(sh, str))
581                 # feed the share through the unpacker as a sanity-check
582                 pieces = unpack_share(sh)
583                 (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
584                  pubkey, signature, share_hash_chain, block_hash_tree,
585                  share_data, enc_privkey) = pieces
586                 self.failUnlessEqual(u_seqnum, 3)
587                 self.failUnlessEqual(u_root_hash, root_hash)
588                 self.failUnlessEqual(k, 3)
589                 self.failUnlessEqual(N, 10)
590                 self.failUnlessEqual(segsize, 21)
591                 self.failUnlessEqual(datalen, len(CONTENTS))
592                 self.failUnlessEqual(pubkey, p._pubkey.serialize())
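                # the material being verified is the signed SDMF prefix:
                # version byte 0, seqnum, root hash, IV/salt, then k, N,
                # segsize, datalen; the verify() check below confirms this is
                # exactly what Publish signed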
593                 sig_material = struct.pack(">BQ32s16s BBQQ",
594                                            0, p._new_seqnum, root_hash, IV,
595                                            k, N, segsize, datalen)
596                 self.failUnless(p._pubkey.verify(sig_material, signature))
597                 #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
598                 self.failUnless(isinstance(share_hash_chain, dict))
599                 self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10))
600                 for shnum,share_hash in share_hash_chain.items():
601                     self.failUnless(isinstance(shnum, int))
602                     self.failUnless(isinstance(share_hash, str))
603                     self.failUnlessEqual(len(share_hash), 32)
604                 self.failUnless(isinstance(block_hash_tree, list))
605                 self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
606                 self.failUnlessEqual(IV, "SALT"*4)
607                 self.failUnlessEqual(len(share_data), len("%07d" % 1))
608                 self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
609         d.addCallback(_generated)
610         return d
611
612     # TODO: when we publish to 20 peers, we should get one share per peer on 10 of them
613     # when we publish to 3 peers, we should get either 3 or 4 shares per peer
614     # when we publish to zero peers, we should get a NotEnoughSharesError
615
616 class Servermap(unittest.TestCase):
617     def setUp(self):
618         # publish a file and create shares, which can then be manipulated
619         # later.
620         num_peers = 20
621         self._client = FakeClient(num_peers)
622         self._storage = self._client._storage
623         d = self._client.create_mutable_file("New contents go here")
624         def _created(node):
625             self._fn = node
626             self._fn2 = self._client.create_node_from_uri(node.get_uri())
627         d.addCallback(_created)
628         return d
629
630     def make_servermap(self, mode=MODE_CHECK, fn=None):
631         if fn is None:
632             fn = self._fn
633         smu = ServermapUpdater(fn, ServerMap(), mode)
634         d = smu.update()
635         return d
636
637     def update_servermap(self, oldmap, mode=MODE_CHECK):
638         smu = ServermapUpdater(self._fn, oldmap, mode)
639         d = smu.update()
640         return d
641
642     def failUnlessOneRecoverable(self, sm, num_shares):
643         self.failUnlessEqual(len(sm.recoverable_versions()), 1)
644         self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
645         best = sm.best_recoverable_version()
646         self.failIfEqual(best, None)
647         self.failUnlessEqual(sm.recoverable_versions(), set([best]))
648         self.failUnlessEqual(len(sm.shares_available()), 1)
649         self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3))
650         shnum, peerids = sm.make_sharemap().items()[0]
651         peerid = list(peerids)[0]
652         self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
653         self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
654         return sm
655
656     def test_basic(self):
657         d = defer.succeed(None)
658         ms = self.make_servermap
659         us = self.update_servermap
660
661         d.addCallback(lambda res: ms(mode=MODE_CHECK))
662         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
663         d.addCallback(lambda res: ms(mode=MODE_WRITE))
664         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
665         d.addCallback(lambda res: ms(mode=MODE_READ))
666         # this mode stops at k+epsilon, and epsilon=k, so 6 shares
667         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
668         d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
669         # this mode stops at 'k' shares
670         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
671
672         # and can we re-use the same servermap? Note that these are sorted in
673         # increasing order of number of servers queried, since once a server
674         # gets into the servermap, we'll always ask it for an update.
675         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
676         d.addCallback(lambda sm: us(sm, mode=MODE_READ))
677         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
678         d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
679         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
680         d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
681         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
682         d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
683         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
684
685         return d
686
687     def test_fetch_privkey(self):
688         d = defer.succeed(None)
689         # use the sibling filenode (which hasn't been used yet), and make
690         # sure it can fetch the privkey. The file is small, so the privkey
691         # will be fetched on the first (query) pass.
692         d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
693         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
694
695         # create a new file, which is large enough to knock the privkey out
696         # of the early part of the file
697         LARGE = "These are Larger contents" * 200 # about 5KB
698         d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
699         def _created(large_fn):
700             large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
701             return self.make_servermap(MODE_WRITE, large_fn2)
702         d.addCallback(_created)
703         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
704         return d
705
706     def test_mark_bad(self):
707         d = defer.succeed(None)
708         ms = self.make_servermap
709         us = self.update_servermap
710
711         d.addCallback(lambda res: ms(mode=MODE_READ))
712         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
713         def _made_map(sm):
714             v = sm.best_recoverable_version()
715             vm = sm.make_versionmap()
716             shares = list(vm[v])
717             self.failUnlessEqual(len(shares), 6)
718             self._corrupted = set()
719             # mark the first 5 shares as corrupt, then update the servermap.
720             # The map should not have the marked shares in it any more, and
721             # new shares should be found to replace the missing ones.
722             for (shnum, peerid, timestamp) in shares:
723                 if shnum < 5:
724                     self._corrupted.add( (peerid, shnum) )
725                     sm.mark_bad_share(peerid, shnum)
726             return self.update_servermap(sm, MODE_WRITE)
727         d.addCallback(_made_map)
728         def _check_map(sm):
729             # this should find all 5 shares that weren't marked bad
730             v = sm.best_recoverable_version()
731             vm = sm.make_versionmap()
732             shares = list(vm[v])
733             for (peerid, shnum) in self._corrupted:
734                 peer_shares = sm.shares_on_peer(peerid)
735                 self.failIf(shnum in peer_shares,
736                             "%d was in %s" % (shnum, peer_shares))
737             self.failUnlessEqual(len(shares), 5)
738         d.addCallback(_check_map)
739         return d
740
741     def failUnlessNoneRecoverable(self, sm):
742         self.failUnlessEqual(len(sm.recoverable_versions()), 0)
743         self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
744         best = sm.best_recoverable_version()
745         self.failUnlessEqual(best, None)
746         self.failUnlessEqual(len(sm.shares_available()), 0)
747
748     def test_no_shares(self):
749         self._client._storage._peers = {} # delete all shares
750         ms = self.make_servermap
751         d = defer.succeed(None)
752
753         d.addCallback(lambda res: ms(mode=MODE_CHECK))
754         d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))
755
756         d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
757         d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))
758
759         d.addCallback(lambda res: ms(mode=MODE_WRITE))
760         d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))
761
762         d.addCallback(lambda res: ms(mode=MODE_READ))
763         d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))
764
765         return d
766
767     def failUnlessNotQuiteEnough(self, sm):
768         self.failUnlessEqual(len(sm.recoverable_versions()), 0)
769         self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
770         best = sm.best_recoverable_version()
771         self.failUnlessEqual(best, None)
772         self.failUnlessEqual(len(sm.shares_available()), 1)
773         self.failUnlessEqual(sm.shares_available().values()[0], (2,3) )
774         return sm
775
776     def test_not_quite_enough_shares(self):
777         s = self._client._storage
778         ms = self.make_servermap
779         num_shares = len(s._peers)
780         for peerid in s._peers:
781             s._peers[peerid] = {}
782             num_shares -= 1
783             if num_shares == 2:
784                 break
785         # now there ought to be only two shares left
786         assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2
787
788         d = defer.succeed(None)
789
790         d.addCallback(lambda res: ms(mode=MODE_CHECK))
791         d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
792         d.addCallback(lambda sm:
793                       self.failUnlessEqual(len(sm.make_sharemap()), 2))
794         d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
795         d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
796         d.addCallback(lambda res: ms(mode=MODE_WRITE))
797         d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
798         d.addCallback(lambda res: ms(mode=MODE_READ))
799         d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
800
801         return d
802
803
804
805 class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
806     def setUp(self):
807         # publish a file and create shares, which can then be manipulated
808         # later.
809         self.CONTENTS = "New contents go here" * 1000
810         num_peers = 20
811         self._client = FakeClient(num_peers)
812         self._storage = self._client._storage
813         d = self._client.create_mutable_file(self.CONTENTS)
814         def _created(node):
815             self._fn = node
816         d.addCallback(_created)
817         return d
818
819     def make_servermap(self, mode=MODE_READ, oldmap=None):
820         if oldmap is None:
821             oldmap = ServerMap()
822         smu = ServermapUpdater(self._fn, oldmap, mode)
823         d = smu.update()
824         return d
825
826     def abbrev_verinfo(self, verinfo):
827         if verinfo is None:
828             return None
829         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
830          offsets_tuple) = verinfo
831         return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])
832
833     def abbrev_verinfo_dict(self, verinfo_d):
834         output = {}
835         for verinfo,value in verinfo_d.items():
836             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
837              offsets_tuple) = verinfo
838             output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
839         return output
840
841     def dump_servermap(self, servermap):
842         print "SERVERMAP", servermap
843         print "RECOVERABLE", [self.abbrev_verinfo(v)
844                               for v in servermap.recoverable_versions()]
845         print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
846         print "available", self.abbrev_verinfo_dict(servermap.shares_available())
847
848     def do_download(self, servermap, version=None):
849         if version is None:
850             version = servermap.best_recoverable_version()
851         r = Retrieve(self._fn, servermap, version)
852         return r.download()
853
854     def test_basic(self):
855         d = self.make_servermap()
856         def _do_retrieve(servermap):
857             self._smap = servermap
858             #self.dump_servermap(servermap)
859             self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
860             return self.do_download(servermap)
861         d.addCallback(_do_retrieve)
862         def _retrieved(new_contents):
863             self.failUnlessEqual(new_contents, self.CONTENTS)
864         d.addCallback(_retrieved)
865         # we should be able to re-use the same servermap, both with and
866         # without updating it.
867         d.addCallback(lambda res: self.do_download(self._smap))
868         d.addCallback(_retrieved)
869         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
870         d.addCallback(lambda res: self.do_download(self._smap))
871         d.addCallback(_retrieved)
872         # clobbering the pubkey should make the servermap updater re-fetch it
873         def _clobber_pubkey(res):
874             self._fn._pubkey = None
875         d.addCallback(_clobber_pubkey)
876         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
877         d.addCallback(lambda res: self.do_download(self._smap))
878         d.addCallback(_retrieved)
879         return d
880
881
882     def _test_corrupt_all(self, offset, substring,
883                           should_succeed=False, corrupt_early=True,
884                           failure_checker=None):
885         d = defer.succeed(None)
886         if corrupt_early:
887             d.addCallback(corrupt, self._storage, offset)
888         d.addCallback(lambda res: self.make_servermap())
889         if not corrupt_early:
890             d.addCallback(corrupt, self._storage, offset)
891         def _do_retrieve(servermap):
892             ver = servermap.best_recoverable_version()
893             if ver is None and not should_succeed:
894                 # no recoverable versions == not succeeding. The problem
895                 # should be noted in the servermap's list of problems.
896                 if substring:
897                     allproblems = [str(f) for f in servermap.problems]
898                     self.failUnless(substring in "".join(allproblems))
899                 return servermap
900             if should_succeed:
901                 d1 = self._fn.download_version(servermap, ver)
902                 d1.addCallback(lambda new_contents:
903                                self.failUnlessEqual(new_contents, self.CONTENTS))
904             else:
905                 d1 = self.shouldFail(NotEnoughSharesError,
906                                      "_corrupt_all(offset=%s)" % (offset,),
907                                      substring,
908                                      self._fn.download_version, servermap, ver)
909             if failure_checker:
910                 d1.addCallback(failure_checker)
911             d1.addCallback(lambda res: servermap)
912             return d1
913         d.addCallback(_do_retrieve)
914         return d
915
916     def test_corrupt_all_verbyte(self):
917         # when the version byte is not 0, we hit an assertion error in
918         # unpack_share().
919         d = self._test_corrupt_all(0, "AssertionError")
920         def _check_servermap(servermap):
921             # and the dump should mention the problems
922             s = StringIO()
923             dump = servermap.dump(s).getvalue()
924             self.failUnless("10 PROBLEMS" in dump, dump)
925         d.addCallback(_check_servermap)
926         return d
927
928     def test_corrupt_all_seqnum(self):
929         # a corrupt sequence number will trigger a bad signature
930         return self._test_corrupt_all(1, "signature is invalid")
931
932     def test_corrupt_all_R(self):
933         # a corrupt root hash will trigger a bad signature
934         return self._test_corrupt_all(9, "signature is invalid")
935
936     def test_corrupt_all_IV(self):
937         # a corrupt salt/IV will trigger a bad signature
938         return self._test_corrupt_all(41, "signature is invalid")
939
940     def test_corrupt_all_k(self):
941         # a corrupt 'k' will trigger a bad signature
942         return self._test_corrupt_all(57, "signature is invalid")
943
944     def test_corrupt_all_N(self):
945         # a corrupt 'N' will trigger a bad signature
946         return self._test_corrupt_all(58, "signature is invalid")
947
948     def test_corrupt_all_segsize(self):
949         # a corrupt segsize will trigger a bad signature
950         return self._test_corrupt_all(59, "signature is invalid")
951
952     def test_corrupt_all_datalen(self):
953         # a corrupt data length will trigger a bad signature
954         return self._test_corrupt_all(67, "signature is invalid")
955
956     def test_corrupt_all_pubkey(self):
957         # a corrupt pubkey won't match the URI's fingerprint. We need to
958         # remove the pubkey from the filenode, or else it won't bother trying
959         # to update it.
960         self._fn._pubkey = None
961         return self._test_corrupt_all("pubkey",
962                                       "pubkey doesn't match fingerprint")
963
964     def test_corrupt_all_sig(self):
965         # a corrupted signature will fail to verify
966         # the signature runs from about [543:799], depending upon the length
967         # of the pubkey
968         return self._test_corrupt_all("signature", "signature is invalid")
969
970     def test_corrupt_all_share_hash_chain_number(self):
971         # a corrupt share hash chain entry will show up as a bad hash. If we
972         # mangle the first byte, that will look like a bad hash number,
973         # causing an IndexError
974         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
975
976     def test_corrupt_all_share_hash_chain_hash(self):
977         # a corrupt share hash chain entry will show up as a bad hash. If we
978         # mangle a few bytes in, that will look like a bad hash.
979         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
980
981     def test_corrupt_all_block_hash_tree(self):
982         return self._test_corrupt_all("block_hash_tree",
983                                       "block hash tree failure")
984
985     def test_corrupt_all_block(self):
986         return self._test_corrupt_all("share_data", "block hash tree failure")
987
988     def test_corrupt_all_encprivkey(self):
989         # a corrupted privkey won't even be noticed by the reader, only by a
990         # writer.
991         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
992
993
994     def test_corrupt_all_seqnum_late(self):
995         # corrupting the seqnum between mapupdate and retrieve should result
996         # in NotEnoughSharesError, since each share will look invalid
997         def _check(res):
998             f = res[0]
999             self.failUnless(f.check(NotEnoughSharesError))
1000             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1001         return self._test_corrupt_all(1, "ran out of peers",
1002                                       corrupt_early=False,
1003                                       failure_checker=_check)
1004
1005     def test_corrupt_all_block_hash_tree_late(self):
1006         def _check(res):
1007             f = res[0]
1008             self.failUnless(f.check(NotEnoughSharesError))
1009         return self._test_corrupt_all("block_hash_tree",
1010                                       "block hash tree failure",
1011                                       corrupt_early=False,
1012                                       failure_checker=_check)
1013
1014
1015     def test_corrupt_all_block_late(self):
1016         def _check(res):
1017             f = res[0]
1018             self.failUnless(f.check(NotEnoughSharesError))
1019         return self._test_corrupt_all("share_data", "block hash tree failure",
1020                                       corrupt_early=False,
1021                                       failure_checker=_check)
1022
1023
1024     def test_basic_pubkey_at_end(self):
1025         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1026         # download to succeed but forcing a bunch of retries first. Note that
1027         # this is rather pessimistic: our Retrieve process will throw away
1028         # the whole share if the pubkey is bad, even though the rest of the
1029         # share might be good.
1030
1031         self._fn._pubkey = None
1032         k = self._fn.get_required_shares()
1033         N = self._fn.get_total_shares()
1034         d = defer.succeed(None)
1035         d.addCallback(corrupt, self._storage, "pubkey",
1036                       shnums_to_corrupt=range(0, N-k))
1037         d.addCallback(lambda res: self.make_servermap())
1038         def _do_retrieve(servermap):
1039             self.failUnless(servermap.problems)
1040             self.failUnless("pubkey doesn't match fingerprint"
1041                             in str(servermap.problems[0]))
1042             ver = servermap.best_recoverable_version()
1043             r = Retrieve(self._fn, servermap, ver)
1044             return r.download()
1045         d.addCallback(_do_retrieve)
1046         d.addCallback(lambda new_contents:
1047                       self.failUnlessEqual(new_contents, self.CONTENTS))
1048         return d
1049
1050     def test_corrupt_some(self):
1051         # corrupt the data of the first five shares (so the servermap thinks
1052         # they're good but retrieve marks them as bad), so that the
1053         # MODE_READ set of 6 will be insufficient, forcing node.download to
1054         # retry with more servers.
1055         corrupt(None, self._storage, "share_data", range(5))
1056         d = self.make_servermap()
1057         def _do_retrieve(servermap):
1058             ver = servermap.best_recoverable_version()
1059             self.failUnless(ver)
1060             return self._fn.download_best_version()
1061         d.addCallback(_do_retrieve)
1062         d.addCallback(lambda new_contents:
1063                       self.failUnlessEqual(new_contents, self.CONTENTS))
1064         return d
1065
1066     def test_download_fails(self):
1067         corrupt(None, self._storage, "signature")
1068         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1069                             "no recoverable versions",
1070                             self._fn.download_best_version)
1071         return d
1072
1073
1074 class MultipleEncodings(unittest.TestCase):
1075     def setUp(self):
1076         self.CONTENTS = "New contents go here"
1077         num_peers = 20
1078         self._client = FakeClient(num_peers)
1079         self._storage = self._client._storage
1080         d = self._client.create_mutable_file(self.CONTENTS)
1081         def _created(node):
1082             self._fn = node
1083         d.addCallback(_created)
1084         return d
1085
1086     def _encode(self, k, n, data):
1087         # encode 'data' into a peerid->shares dict.
1088
1089         fn2 = FastMutableFileNode(self._client)
1090         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1091         # and _fingerprint
1092         fn = self._fn
1093         fn2.init_from_uri(fn.get_uri())
1094         # then we copy over other fields that are normally fetched from the
1095         # existing shares
1096         fn2._pubkey = fn._pubkey
1097         fn2._privkey = fn._privkey
1098         fn2._encprivkey = fn._encprivkey
1099         # and set the encoding parameters to something completely different
1100         fn2._required_shares = k
1101         fn2._total_shares = n
1102
1103         s = self._client._storage
1104         s._peers = {} # clear existing storage
1105         p2 = Publish(fn2, None)
1106         d = p2.publish(data)
1107         def _published(res):
1108             shares = s._peers
1109             s._peers = {}
1110             return shares
1111         d.addCallback(_published)
1112         return d
1113
1114     def make_servermap(self, mode=MODE_READ, oldmap=None):
1115         if oldmap is None:
1116             oldmap = ServerMap()
1117         smu = ServermapUpdater(self._fn, oldmap, mode)
1118         d = smu.update()
1119         return d
1120
1121     def test_multiple_encodings(self):
1122         # we encode the same file in three different ways (3-of-10, 4-of-9, and 4-of-7),
1123         # then mix up the shares, to make sure that download survives seeing
1124         # a variety of encodings. This is actually kind of tricky to set up.
1125
1126         contents1 = "Contents for encoding 1 (3-of-10) go here"
1127         contents2 = "Contents for encoding 2 (4-of-9) go here"
1128         contents3 = "Contents for encoding 3 (4-of-7) go here"
1129
1130         # we make a retrieval object that doesn't know what encoding
1131         # parameters to use
1132         fn3 = FastMutableFileNode(self._client)
1133         fn3.init_from_uri(self._fn.get_uri())
1134
1135         # now we upload a file through fn1, and grab its shares
1136         d = self._encode(3, 10, contents1)
1137         def _encoded_1(shares):
1138             self._shares1 = shares
1139         d.addCallback(_encoded_1)
1140         d.addCallback(lambda res: self._encode(4, 9, contents2))
1141         def _encoded_2(shares):
1142             self._shares2 = shares
1143         d.addCallback(_encoded_2)
1144         d.addCallback(lambda res: self._encode(4, 7, contents3))
1145         def _encoded_3(shares):
1146             self._shares3 = shares
1147         d.addCallback(_encoded_3)
1148
1149         def _merge(res):
1150             log.msg("merging sharelists")
1151             # we merge the shares from the two sets, leaving each shnum in
1152             # its original location, but using a share from set1 or set2
1153             # according to the following sequence:
1154             #
1155             #  4-of-9  a  s2
1156             #  4-of-9  b  s2
1157             #  4-of-7  c   s3
1158             #  4-of-9  d  s2
1159             #  3-of-10 e s1
1160             #  3-of-10 f s1
1161             #  3-of-10 g s1
1162             #  4-of-9  h  s2
1163             #
1164             # so that no version can be recovered until fetch [f], at which
1165             # point version-s1 (the 3-of-10 form) should be recoverable. If
1166             # the implementation latches on to the first version it sees,
1167             # then s2 will be recoverable at fetch [g].
1168
1169             # Later, when we implement code that handles multiple versions,
1170             # we can use this framework to assert that all recoverable
1171             # versions are retrieved, and test that 'epsilon' does its job
1172
1173             places = [2, 2, 3, 2, 1, 1, 1, 2]
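            # one entry per shnum 0..7, matching rows a..h above: the value
            # says which sharelist (_shares1, _shares2, or _shares3) supplies
            # that share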
1174
1175             sharemap = {}
1176
1177             for i,peerid in enumerate(self._client._peerids):
1178                 peerid_s = shortnodeid_b2a(peerid)
1179                 for shnum in self._shares1.get(peerid, {}):
1180                     if shnum < len(places):
1181                         which = places[shnum]
1182                     else:
1183                         which = "x"
1184                     self._client._storage._peers[peerid] = peers = {}
1185                     in_1 = shnum in self._shares1[peerid]
1186                     in_2 = shnum in self._shares2.get(peerid, {})
1187                     in_3 = shnum in self._shares3.get(peerid, {})
1188                     #print peerid_s, shnum, which, in_1, in_2, in_3
1189                     if which == 1:
1190                         if in_1:
1191                             peers[shnum] = self._shares1[peerid][shnum]
1192                             sharemap[shnum] = peerid
1193                     elif which == 2:
1194                         if in_2:
1195                             peers[shnum] = self._shares2[peerid][shnum]
1196                             sharemap[shnum] = peerid
1197                     elif which == 3:
1198                         if in_3:
1199                             peers[shnum] = self._shares3[peerid][shnum]
1200                             sharemap[shnum] = peerid
1201
1202             # we don't bother placing any other shares
1203             # now sort the sequence so that share 0 is returned first
1204             new_sequence = [sharemap[shnum]
1205                             for shnum in sorted(sharemap.keys())]
1206             self._client._storage._sequence = new_sequence
1207             log.msg("merge done")
1208         d.addCallback(_merge)
1209         d.addCallback(lambda res: fn3.download_best_version())
1210         def _retrieved(new_contents):
1211             # the current specified behavior is "first version recoverable"
1212             self.failUnlessEqual(new_contents, contents1)
1213         d.addCallback(_retrieved)
1214         return d
1215
1216 class MultipleVersions(unittest.TestCase):
1217     def setUp(self):
1218         self.CONTENTS = ["Contents 0",
1219                          "Contents 1",
1220                          "Contents 2",
1221                          "Contents 3a",
1222                          "Contents 3b"]
1223         self._copied_shares = {}
1224         num_peers = 20
1225         self._client = FakeClient(num_peers)
1226         self._storage = self._client._storage
1227         d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
1228         def _created(node):
1229             self._fn = node
1230             # now create multiple versions of the same file, and accumulate
1231             # their shares, so we can mix and match them later.
1232             d = defer.succeed(None)
1233             d.addCallback(self._copy_shares, 0)
1234             d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
1235             d.addCallback(self._copy_shares, 1)
1236             d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
1237             d.addCallback(self._copy_shares, 2)
1238             d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
1239             d.addCallback(self._copy_shares, 3)
1240             # now we replace all the shares with version s3, and upload a new
1241             # version to get s4b.
1242             rollback = dict([(i,2) for i in range(10)])
1243             d.addCallback(lambda res: self._set_versions(rollback))
1244             d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
1245             d.addCallback(self._copy_shares, 4)
1246             # we leave the storage in state 4
1247             return d
1248         d.addCallback(_created)
1249         return d
1250
1251     def _copy_shares(self, ignored, index):
1252         shares = self._client._storage._peers
1253         # we need a deep copy
1254         new_shares = {}
1255         for peerid in shares:
1256             new_shares[peerid] = {}
1257             for shnum in shares[peerid]:
1258                 new_shares[peerid][shnum] = shares[peerid][shnum]
1259         self._copied_shares[index] = new_shares
1260
1261     def _set_versions(self, versionmap):
1262         # versionmap maps shnums to which version (0,1,2,3,4) we want the
1263         # share to be at. Any shnum which is left out of the map will stay at
1264         # its current version.
1265         shares = self._client._storage._peers
1266         oldshares = self._copied_shares
1267         for peerid in shares:
1268             for shnum in shares[peerid]:
1269                 if shnum in versionmap:
1270                     index = versionmap[shnum]
1271                     shares[peerid][shnum] = oldshares[index][peerid][shnum]
1272
1273     def test_multiple_versions(self):
1274         # if we see a mix of versions in the grid, download_best_version
1275         # should get the latest one
1276         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1277         d = self._fn.download_best_version()
1278         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1279         # but if everything is at version 2, that's what we should download
1280         d.addCallback(lambda res:
1281                       self._set_versions(dict([(i,2) for i in range(10)])))
1282         d.addCallback(lambda res: self._fn.download_best_version())
1283         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1284         # if exactly one share is at version 3, we should still get v2
1285         d.addCallback(lambda res:
1286                       self._set_versions({0:3}))
1287         d.addCallback(lambda res: self._fn.download_best_version())
1288         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1289         # but the servermap should see the unrecoverable version. This
1290         # depends upon the single newer share being queried early.
1291         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1292         def _check_smap(smap):
1293             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1294             newer = smap.unrecoverable_newer_versions()
1295             self.failUnlessEqual(len(newer), 1)
1296             verinfo, health = newer.items()[0]
1297             self.failUnlessEqual(verinfo[0], 4)
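                 # health is presumably (shares found, shares needed): one share
                 # of the newer version was seen, and k=3 are required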
1298             self.failUnlessEqual(health, (1,3))
1299             self.failIf(smap.needs_merge())
1300         d.addCallback(_check_smap)
1301         # if we have a mix of two parallel versions (s4a and s4b), we could
1302         # recover either
1303         d.addCallback(lambda res:
1304                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1305                                           1:4,3:4,5:4,7:4,9:4}))
1306         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1307         def _check_smap_mixed(smap):
1308             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1309             newer = smap.unrecoverable_newer_versions()
1310             self.failUnlessEqual(len(newer), 0)
1311             self.failUnless(smap.needs_merge())
1312         d.addCallback(_check_smap_mixed)
1313         d.addCallback(lambda res: self._fn.download_best_version())
1314         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1315                                                   res == self.CONTENTS[4]))
1316         return d
1317
1318     def test_replace(self):
1319         # if we see a mix of versions in the grid, we should be able to
1320         # replace them all with a newer version
1321
1322         # if exactly one share is at version 3, we should download (and
1323         # replace) v2, and the result should be v4. Note that the index we
1324         # give to _set_versions is different than the sequence number.
1325         target = dict([(i,2) for i in range(10)]) # seqnum3
1326         target[0] = 3 # seqnum4
1327         self._set_versions(target)
1328
1329         def _modify(oldversion):
1330             return oldversion + " modified"
1331         d = self._fn.modify(_modify)
1332         d.addCallback(lambda res: self._fn.download_best_version())
1333         expected = self.CONTENTS[2] + " modified"
1334         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1335         # and the servermap should indicate that the outlier was replaced too
1336         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1337         def _check_smap(smap):
1338             self.failUnlessEqual(smap.highest_seqnum(), 5)
1339             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1340             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1341         d.addCallback(_check_smap)
1342         return d
1343
1344
1345 class Utils(unittest.TestCase):
1346     def test_dict_of_sets(self):
1347         ds = DictOfSets()
1348         ds.add(1, "a")
1349         ds.add(2, "b")
1350         ds.add(2, "b")
1351         ds.add(2, "c")
1352         self.failUnlessEqual(ds[1], set(["a"]))
1353         self.failUnlessEqual(ds[2], set(["b", "c"]))
1354         ds.discard(3, "d") # should not raise an exception
1355         ds.discard(2, "b")
1356         self.failUnlessEqual(ds[2], set(["c"]))
1357         ds.discard(2, "c")
1358         self.failIf(2 in ds)
1359
1360     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1361         # we compare this against sets of integers
1362         x = set(range(x_start, x_start+x_length))
1363         y = set(range(y_start, y_start+y_length))
1364         should_be_inside = x.issubset(y)
1365         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1366                                                          y_start, y_length),
1367                              str((x_start, x_length, y_start, y_length)))
1368
1369     def test_cache_inside(self):
1370         c = ResponseCache()
1371         x_start = 10
1372         x_length = 5
1373         for y_start in range(8, 17):
1374             for y_length in range(8):
1375                 self._do_inside(c, x_start, x_length, y_start, y_length)
1376
1377     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1378         # we compare this against sets of integers
1379         x = set(range(x_start, x_start+x_length))
1380         y = set(range(y_start, y_start+y_length))
1381         overlap = bool(x.intersection(y))
1382         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1383                                                       y_start, y_length),
1384                              str((x_start, x_length, y_start, y_length)))
1385
1386     def test_cache_overlap(self):
1387         c = ResponseCache()
1388         x_start = 10
1389         x_length = 5
1390         for y_start in range(8, 17):
1391             for y_length in range(8):
1392                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1393
1394     def test_cache(self):
1395         c = ResponseCache()
1396         # xdata = base62.b2a(os.urandom(100))[:100]
1397         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1398         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1399         nope = (None, None)
1400         c.add("v1", 1, 0, xdata, "time0")
1401         c.add("v1", 1, 2000, ydata, "time1")
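             # two segments are cached for ("v1", shnum 1): xdata at [0,100) and
             # ydata at [2000,2100); the assertions below expect a hit only when
             # the requested range lies entirely inside one cached segment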
1402         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1403         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1404         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1405         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1406         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1407         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1408         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1409         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1410         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
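             # reads in the gap before offset 2000 miss, even when they extend
             # into the cached ydata segment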
1411         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1412         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1413         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1414         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1415         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1416         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1417         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1418         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1419         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1420
1421         # optional: join fragments
1422         c = ResponseCache()
1423         c.add("v1", 1, 0, xdata[:10], "time0")
1424         c.add("v1", 1, 10, xdata[10:20], "time1")
1425         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1426
1427 class Exceptions(unittest.TestCase):
1428     def test_repr(self):
1429         nmde = NeedMoreDataError(100, 50, 100)
1430         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1431         ucwe = UncoordinatedWriteError()
1432         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1433
1434 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1435 # instances which always succeed. So we need a less-fake one.
1436
1437 class IntentionalError(Exception):
1438     pass
1439
1440 class LocalWrapper:
1441     def __init__(self, original):
1442         self.original = original
1443         self.broken = False
1444         self.post_call_notifier = None
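         # callRemote imitates a Foolscap remote call: the wrapped method runs
         # in a later reactor turn (via fireEventually), raises IntentionalError
         # when the wrapper is marked broken, and post_call_notifier (if set) is
         # called with each result and method name after the call completes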
1445     def callRemote(self, methname, *args, **kwargs):
1446         def _call():
1447             if self.broken:
1448                 raise IntentionalError("I was asked to break")
1449             meth = getattr(self.original, "remote_" + methname)
1450             return meth(*args, **kwargs)
1451         d = fireEventually()
1452         d.addCallback(lambda res: _call())
1453         if self.post_call_notifier:
1454             d.addCallback(self.post_call_notifier, methname)
1455         return d
1456
1457 class LessFakeClient(FakeClient):
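         # like FakeClient, but each peer is a real storage.StorageServer backed
         # by a directory under basedir and wrapped in a LocalWrapper, so tests
         # can break or observe individual servers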
1458
1459     def __init__(self, basedir, num_peers=10):
1460         self._num_peers = num_peers
1461         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1462                          for i in range(self._num_peers)]
1463         self._connections = {}
1464         for peerid in self._peerids:
1465             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1466             make_dirs(peerdir)
1467             ss = storage.StorageServer(peerdir)
1468             ss.setNodeID(peerid)
1469             lw = LocalWrapper(ss)
1470             self._connections[peerid] = lw
1471         self.nodeid = "fakenodeid"
1472
1473
1474 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1475     def test_publish_surprise(self):
1476         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1477         self.client = LessFakeClient(basedir)
1478         d = self.client.create_mutable_file("contents 1")
1479         def _created(n):
1480             d = defer.succeed(None)
1481             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1482             def _got_smap1(smap):
1483                 # stash the old state of the file
1484                 self.old_map = smap
1485             d.addCallback(_got_smap1)
1486             # then modify the file, leaving the old map untouched
1487             d.addCallback(lambda res: log.msg("starting winning write"))
1488             d.addCallback(lambda res: n.overwrite("contents 2"))
1489             # now attempt to modify the file with the old servermap. This
1490             # will look just like an uncoordinated write, in which every
1491             # single share got updated between our mapupdate and our publish
1492             d.addCallback(lambda res: log.msg("starting doomed write"))
1493             d.addCallback(lambda res:
1494                           self.shouldFail(UncoordinatedWriteError,
1495                                           "test_publish_surprise", None,
1496                                           n.upload,
1497                                           "contents 2a", self.old_map))
1498             return d
1499         d.addCallback(_created)
1500         return d
1501
1502     def test_retrieve_surprise(self):
1503         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1504         self.client = LessFakeClient(basedir)
1505         d = self.client.create_mutable_file("contents 1")
1506         def _created(n):
1507             d = defer.succeed(None)
1508             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1509             def _got_smap1(smap):
1510                 # stash the old state of the file
1511                 self.old_map = smap
1512             d.addCallback(_got_smap1)
1513             # then modify the file, leaving the old map untouched
1514             d.addCallback(lambda res: log.msg("starting winning write"))
1515             d.addCallback(lambda res: n.overwrite("contents 2"))
1516             # now attempt to retrieve the old version with the old servermap.
1517             # This will look like someone has changed the file since we
1518             # updated the servermap.
1519             d.addCallback(lambda res: n._cache._clear())
1520             d.addCallback(lambda res: log.msg("starting doomed read"))
1521             d.addCallback(lambda res:
1522                           self.shouldFail(NotEnoughSharesError,
1523                                           "test_retrieve_surprise",
1524                                           "ran out of peers: have 0 shares (k=3)",
1525                                           n.download_version,
1526                                           self.old_map,
1527                                           self.old_map.best_recoverable_version(),
1528                                           ))
1529             return d
1530         d.addCallback(_created)
1531         return d
1532
1533     def test_unexpected_shares(self):
1534         # upload the file, take a servermap, shut down one of the servers,
1535         # upload it again (causing shares to appear on a new server), then
1536         # upload using the old servermap. The last upload should fail with an
1537         # UncoordinatedWriteError, because of the shares that didn't appear
1538         # in the servermap.
1539         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1540         self.client = LessFakeClient(basedir)
1541         d = self.client.create_mutable_file("contents 1")
1542         def _created(n):
1543             d = defer.succeed(None)
1544             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1545             def _got_smap1(smap):
1546                 # stash the old state of the file
1547                 self.old_map = smap
1548                 # now shut down one of the servers
1549                 peer0 = list(smap.make_sharemap()[0])[0]
1550                 self.client._connections.pop(peer0)
1551                 # then modify the file, leaving the old map untouched
1552                 log.msg("starting winning write")
1553                 return n.overwrite("contents 2")
1554             d.addCallback(_got_smap1)
1555             # now attempt to modify the file with the old servermap. This
1556             # will look just like an uncoordinated write, in which every
1557             # single share got updated between our mapupdate and our publish
1558             d.addCallback(lambda res: log.msg("starting doomed write"))
1559             d.addCallback(lambda res:
1560                           self.shouldFail(UncoordinatedWriteError,
1561                                           "test_unexpected_shares", None,
1562                                           n.upload,
1563                                           "contents 2a", self.old_map))
1564             return d
1565         d.addCallback(_created)
1566         return d
1567
1568     def test_bad_server(self):
1569         # Break one server, then create the file: the initial publish should
1570         # complete with an alternate server. Breaking a second server should
1571         # not prevent an update from succeeding either.
1572         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1573         self.client = LessFakeClient(basedir, 20)
1574         # to make sure that one of the initial peers is broken, we have to
1575         # get creative. We create the keys, so we can figure out the storage
1576         # index, but we hold off on doing the initial publish until we've
1577         # broken the server on which the first share wants to be stored.
1578         n = FastMutableFileNode(self.client)
1579         d = defer.succeed(None)
1580         d.addCallback(n._generate_pubprivkeys)
1581         d.addCallback(n._generated)
1582         def _break_peer0(res):
1583             si = n.get_storage_index()
1584             peerlist = self.client.get_permuted_peers("storage", si)
1585             peerid0, connection0 = peerlist[0]
1586             peerid1, connection1 = peerlist[1]
1587             connection0.broken = True
1588             self.connection1 = connection1
1589         d.addCallback(_break_peer0)
1590         # now let the initial publish finally happen
1591         d.addCallback(lambda res: n._upload("contents 1", None))
1592         # that ought to work
1593         d.addCallback(lambda res: n.download_best_version())
1594         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1595         # now break the second peer
1596         def _break_peer1(res):
1597             self.connection1.broken = True
1598         d.addCallback(_break_peer1)
1599         d.addCallback(lambda res: n.overwrite("contents 2"))
1600         # that ought to work too
1601         d.addCallback(lambda res: n.download_best_version())
1602         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1603         return d
1604
1605     def test_publish_all_servers_bad(self):
1606         # Break all servers: the publish should fail
1607         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1608         self.client = LessFakeClient(basedir, 20)
1609         for connection in self.client._connections.values():
1610             connection.broken = True
1611         d = self.shouldFail(NotEnoughServersError,
1612                             "test_publish_all_servers_bad",
1613                             "Ran out of non-bad servers",
1614                             self.client.create_mutable_file, "contents")
1615         return d
1616
1617     def test_privkey_query_error(self):
1618         # when a servermap is updated with MODE_WRITE, it tries to get the
1619         # privkey. Something might go wrong during this query attempt.
1620         self.client = FakeClient(20)
1621         # we need some contents that are large enough to push the privkey out
1622         # of the early part of the file
1623         LARGE = "These are Larger contents" * 200 # about 5KB
1624         d = self.client.create_mutable_file(LARGE)
1625         def _created(n):
1626             self.uri = n.get_uri()
1627             self.n2 = self.client.create_node_from_uri(self.uri)
1628             # we start by doing a map update to figure out which is the first
1629             # server.
1630             return n.get_servermap(MODE_WRITE)
1631         d.addCallback(_created)
1632         d.addCallback(lambda res: fireEventually(res))
1633         def _got_smap1(smap):
1634             peer0 = list(smap.make_sharemap()[0])[0]
1635             # we tell the server to respond to this peer first, so that it
1636             # will be asked for the privkey first
1637             self.client._storage._sequence = [peer0]
1638             # now we make the peer fail their second query
1639             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1640         d.addCallback(_got_smap1)
1641         # now we update a servermap from a new node (which doesn't have the
1642         # privkey yet, forcing it to use a separate privkey query). Each
1643         # query response will trigger a privkey query, and since we're using
1644         # _sequence to make the peer0 response come back first, we'll send it
1645         # a privkey query first, and _sequence will again ensure that the
1646         # peer0 query will also come back before the others, and then
1647         # _special_answers will make sure that the query raises an exception.
1648         # The whole point of these hijinks is to exercise the code in
1649         # _privkey_query_failed. Note that the map-update will succeed, since
1650         # we'll just get a copy from one of the other shares.
1651         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1652         # Using FakeStorage._sequence means there will be read requests still
1653         # floating around; wait for them to retire
1654         def _cancel_timer(res):
1655             if self.client._storage._pending_timer:
1656                 self.client._storage._pending_timer.cancel()
1657             return res
1658         d.addBoth(_cancel_timer)
1659         return d
1660
1661     def test_privkey_query_missing(self):
1662         # like test_privkey_query_error, but the shares are deleted by the
1663         # second query, instead of raising an exception.
1664         self.client = FakeClient(20)
1665         LARGE = "These are Larger contents" * 200 # about 5KB
1666         d = self.client.create_mutable_file(LARGE)
1667         def _created(n):
1668             self.uri = n.get_uri()
1669             self.n2 = self.client.create_node_from_uri(self.uri)
1670             return n.get_servermap(MODE_WRITE)
1671         d.addCallback(_created)
1672         d.addCallback(lambda res: fireEventually(res))
1673         def _got_smap1(smap):
1674             peer0 = list(smap.make_sharemap()[0])[0]
1675             self.client._storage._sequence = [peer0]
1676             self.client._storage._special_answers[peer0] = ["normal", "none"]
1677         d.addCallback(_got_smap1)
1678         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1679         def _cancel_timer(res):
1680             if self.client._storage._pending_timer:
1681                 self.client._storage._pending_timer.cancel()
1682             return res
1683         d.addBoth(_cancel_timer)
1684         return d