]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_mutable.py
Change relative imports to absolute
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_mutable.py
1
2 import struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from allmydata import uri, client
7 from allmydata.nodemaker import NodeMaker
8 from allmydata.util import base32
9 from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
10      ssk_pubkey_fingerprint_hash
11 from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
12      NotEnoughSharesError
13 from allmydata.monitor import Monitor
14 from allmydata.test.common import ShouldFailMixin
15 from allmydata.test.no_network import GridTestMixin
16 from foolscap.api import eventually, fireEventually
17 from foolscap.logging import log
18 from allmydata.storage_client import StorageFarmBroker
19
20 from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
21 from allmydata.mutable.common import ResponseCache, \
22      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
23      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
24      NotEnoughServersError, CorruptShareError
25 from allmydata.mutable.retrieve import Retrieve
26 from allmydata.mutable.publish import Publish
27 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
28 from allmydata.mutable.layout import unpack_header, unpack_share
29 from allmydata.mutable.repairer import MustForceRepairError
30
31 import allmydata.test.common_util as testutil
32
33 # this "FakeStorage" exists to put the share data in RAM and avoid using real
34 # network connections, both to speed up the tests and to reduce the amount of
35 # non-mutable.py code being exercised.
36
class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        # maps peerid -> {shnum: share data (str)}
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        # maps peerid -> list of (Deferred, shares) tuples, one per read
        # that is waiting for _fire_readers. A list is used so that several
        # reads from the same peer inside one timer window are all answered
        # instead of the later one silently replacing the earlier one.
        self._pending = {}
        self._pending_timer = None

    def read(self, peerid, storage_index):
        """Return a Deferred that fires with the {shnum: data} dict held
        for 'peerid' (an empty dict if that peer has no shares).

        When self._sequence is None the Deferred has already fired.
        Otherwise the answer is queued and released by _fire_readers, one
        second after the first queued read. 'storage_index' is ignored.
        """
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            # first queued read of this batch: start the release timer
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        if peerid not in self._pending:
            self._pending[peerid] = []
        self._pending[peerid].append((d, shares))
        return d

    def _fire_readers(self):
        """Release every queued read: peers named in self._sequence first,
        in that order, then any remaining peers in dict order."""
        self._pending_timer = None
        # swap in a fresh dict so that reads arriving from now on start a
        # new batch (and a new timer)
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                for (d, shares) in pending.pop(peerid):
                    eventually(d.callback, shares)
        for waiters in pending.values():
            for (d, shares) in waiters:
                eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        """Overwrite share 'shnum' on 'peerid' with 'data' starting at
        'offset', creating the peer entry and/or the share if necessary.
        'storage_index' is ignored."""
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        # splice the new bytes into the existing share contents
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
89
90
class FakeStorageServer:
    # In-RAM stand-in for a remote reference to one storage server. All
    # share data lives in the shared FakeStorage object; this class only
    # supplies the callRemote()-style interface, dispatching each method
    # name to the same-named method on itself.

    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0

    def callRemote(self, methname, *args, **kwargs):
        """Invoke self.<methname>(*args, **kwargs) in a later reactor turn
        and return a Deferred for its result."""
        def _invoke(ignored):
            # the method is looked up when the call actually runs
            return getattr(self, methname)(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_invoke)
        return d

    def callRemoteOnly(self, methname, *args, **kwargs):
        """Fire-and-forget form of callRemote: both the result and any
        failure are discarded, and nothing is returned."""
        self.callRemote(methname, *args, **kwargs).addBoth(lambda ignore: None)
        return None

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        # corruption reports are accepted and ignored
        pass

    def slot_readv(self, storage_index, shnums, readv):
        """Return a Deferred firing with {shnum: [data, ...]}: one string
        per (offset, length) pair in 'readv'. An empty 'shnums' selects
        every share this peer holds."""
        def _build_response(shares):
            response = {}
            wanted = [num for num in shares
                      if not shnums or num in shnums]
            for num in wanted:
                chunks = []
                response[num] = chunks
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    chunks.append(shares[num][offset:offset+length])
            return response
        d = self.storage.read(self.peerid, storage_index)
        d.addCallback(_build_response)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        """Apply every write vector unconditionally and return a Deferred
        that fires with (True, readv). 'secrets' and 'read_vector' are not
        consulted."""
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            specimens = []
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
                specimens.append(specimen)
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = specimens
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        return fireEventually((True, readv))
144
145
def flip_bit(original, byte_offset):
    """Return a copy of the string 'original' in which the least
    significant bit of the byte at 'byte_offset' has been inverted."""
    head = original[:byte_offset]
    flipped = chr(ord(original[byte_offset]) ^ 0x01)
    tail = original[byte_offset+1:]
    return head + flipped + tail
150
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    """Flip one bit in shares held by the FakeStorage 's', then return
    'res' unchanged (so this can be used directly as a Deferred callback).

    'offset' is either a single value or a (value, extra) tuple. The value
    may be the string "pubkey" (mapped to byte 107), a key of the offset
    table returned by unpack_header(), or a plain number. 'extra' and
    'offset_offset' are added to the resolved position.

    if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    list of shnums to corrupt.
    """
    for shares in s._peers.values():
        for shnum in shares:
            if shnums_to_corrupt is None or shnum in shnums_to_corrupt:
                data = shares[shnum]
                # only the offset table (the last field) is needed here
                (_version, _seqnum, _root_hash, _IV,
                 _k, _N, _segsize, _datalen,
                 offset_table) = unpack_header(data)
                if isinstance(offset, tuple):
                    base, extra = offset
                else:
                    base, extra = offset, 0
                if base == "pubkey":
                    position = 107
                elif base in offset_table:
                    position = offset_table[base]
                else:
                    position = base
                position = int(position) + extra + offset_offset
                assert isinstance(position, int), offset
                shares[shnum] = flip_bit(data, position)
    return res
182
def make_storagebroker(s=None, num_peers=10):
    """Build a StorageFarmBroker holding 'num_peers' FakeStorageServers
    that all share the FakeStorage 's' (a new one is created if 's' is not
    supplied). Peer ids are the first 20 bytes of tagged_hash("peerid",
    "<index>")."""
    if not s:
        s = FakeStorage()
    broker = StorageFarmBroker(None, True)
    for i in range(num_peers):
        peerid = tagged_hash("peerid", "%d" % i)[:20]
        broker.test_add_server(peerid, FakeStorageServer(peerid, s))
    return broker
193
def make_nodemaker(s=None, num_peers=10):
    """Return a NodeMaker wired to a fake grid of 'num_peers' servers
    backed by the FakeStorage 's', with 3-of-10 encoding and a key
    generator whose default key size is 522 (presumably kept small so the
    tests run quickly)."""
    keygen = client.KeyGenerator()
    keygen.set_default_keysize(522)
    secret_holder = client.SecretHolder("lease secret", "convergence secret")
    encoding_parameters = {"k": 3, "n": 10}
    broker = make_storagebroker(s, num_peers)
    return NodeMaker(broker, secret_holder, None,
                     None, None, None,
                     encoding_parameters, keygen)
203
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # Tests of MutableFileNode (create / overwrite / modify / download),
    # run against the in-RAM grid built by make_nodemaker().

    # this used to be in Publish, but we removed the limit. Some of
    # these tests test whether the new code correctly allows files
    # larger than the limit.
    OLD_MAX_SEGMENT_SIZE = 3500000

    def setUp(self):
        self._storage = s = FakeStorage()
        self.nodemaker = make_nodemaker(s)

    def test_create(self):
        """Creating an empty mutable file yields a MutableFileNode and
        leaves exactly one share on the first server (in sorted serverid
        order)."""
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, MutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.nodemaker.storage_broker
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        """_do_serialized passes positional and keyword arguments through,
        delivers the callable's return value, and propagates exceptions."""
        n = MutableFileNode(None, None, {"k": 3, "n": 10}, None)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def _check_upload_and_download(self, n):
        """Run the common overwrite / upload / download round trip against
        node 'n' and return a Deferred that fires when every check has
        passed. Shared by the test_upload_and_download* tests so that the
        two stay in step."""
        d = defer.succeed(None)
        d.addCallback(lambda res: n.get_servermap(MODE_READ))
        d.addCallback(lambda smap: smap.dump(StringIO()))
        d.addCallback(lambda sio:
                      self.failUnless("3-of-10" in sio.getvalue()))
        d.addCallback(lambda res: n.overwrite("contents 1"))
        d.addCallback(lambda res: self.failUnlessIdentical(res, None))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        d.addCallback(lambda res: n.get_size_of_best_version())
        d.addCallback(lambda size:
                      self.failUnlessEqual(size, len("contents 1")))
        d.addCallback(lambda res: n.overwrite("contents 2"))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
        d.addCallback(lambda smap: n.upload("contents 3", smap))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
        d.addCallback(lambda smap:
                      n.download_version(smap,
                                         smap.best_recoverable_version()))
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        return d

    def test_upload_and_download(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = self._check_upload_and_download(n)
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.nodemaker.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents_function(self):
        """create_mutable_file also accepts a callable: it is handed the
        new node (whose writekey is already available) and returns the
        initial contents."""
        data = "initial contents"
        def _make_contents(n):
            self.failUnless(isinstance(n, MutableFileNode))
            key = n.get_writekey()
            self.failUnless(isinstance(key, str), key)
            self.failUnlessEqual(len(key), 16) # AES key size
            return data
        d = self.nodemaker.create_mutable_file(_make_contents)
        def _created(n):
            return n.download_best_version()
        d.addCallback(_created)
        d.addCallback(lambda data2: self.failUnlessEqual(data2, data))
        return d

    def test_create_with_too_large_contents(self):
        # one byte more than the old segment-size limit must now be
        # accepted, both at creation time and on overwrite
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.nodemaker.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        """Return a Deferred that fails unless the best recoverable version
        of 'n' has sequence number 'expected_seqnum'. 'which' is passed to
        failUnlessEqual as the failure message."""
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            # returning the old contents unchanged must not publish
            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            # neither must returning None
            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            # an exception in the modifier propagates and leaves the file
            # untouched
            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            # a modifier that returns more than the old size limit is
            # expected to succeed now. It runs last because it replaces the
            # file contents, which would invalidate the checks above.
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            # returning the Failure makes modify() give up immediately
            return f
        def _backoff_pauser(node, f):
            # wait half a second, then let modify() retry
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        # same round trip as test_upload_and_download, but with a key
        # generator left at its default key size instead of the 522 that
        # make_nodemaker() configures
        self.nodemaker.key_generator = client.KeyGenerator()
        d = self.nodemaker.create_mutable_file()
        d.addCallback(self._check_upload_and_download)
        return d
505
506
class MakeShares(unittest.TestCase):
    # Drive the internal stages of Publish directly (encrypt+encode, then
    # share generation) and sanity-check what they produce.

    def test_encrypt(self):
        """_encrypt_and_encode yields ten 7-byte string shares and ten
        share ids for a 21-byte plaintext with 3-of-10 encoding."""
        CONTENTS = "some initial contents"
        nm = make_nodemaker()
        def _encode(fn):
            publisher = Publish(fn, nm.storage_broker, None)
            publisher.salt = "SALT" * 4
            publisher.readkey = "\x00" * 16
            publisher.newdata = CONTENTS
            publisher.required_shares = 3
            publisher.total_shares = 10
            publisher.setup_encoding_parameters()
            return publisher._encrypt_and_encode()
        def _check(shares_and_shareids):
            shares, share_ids = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for share in shares:
                self.failUnless(isinstance(share, str))
                self.failUnlessEqual(len(share), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d = nm.create_mutable_file(CONTENTS)
        d.addCallback(_encode)
        d.addCallback(_check)
        return d

    def test_generate(self):
        """_generate_shares wraps ten fake 7-byte blocks into complete
        shares; each one is run back through unpack_share and every field
        is checked."""
        CONTENTS = "some initial contents"
        nm = make_nodemaker()
        def _generate(fn):
            self._fn = fn
            publisher = Publish(fn, nm.storage_broker, None)
            self._p = publisher
            publisher.newdata = CONTENTS
            publisher.required_shares = 3
            publisher.total_shares = 10
            publisher.setup_encoding_parameters()
            publisher._new_seqnum = 3
            publisher.salt = "SALT" * 4
            # make some fake shares
            fake_blocks = ["%07d" % i for i in range(10)]
            fake_ids = range(10)
            publisher._privkey = fn.get_privkey()
            publisher._encprivkey = fn.get_encprivkey()
            publisher._pubkey = fn.get_pubkey()
            return publisher._generate_shares( (fake_blocks, fake_ids) )
        def _verify(res):
            publisher = self._p
            final_shares = publisher.shares
            root_hash = publisher.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares), range(10))
            for share in final_shares.values():
                self.failUnless(isinstance(share, str))
                # feed the share through the unpacker as a sanity-check
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = unpack_share(share)
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, publisher._pubkey.serialize())
                # rebuild the signed header prefix and check the signature
                # with the public key (the signature itself is not compared
                # against a freshly computed one)
                signed_prefix = struct.pack(">BQ32s16s BBQQ",
                                            0, publisher._new_seqnum,
                                            root_hash, IV,
                                            k, N, segsize, datalen)
                self.failUnless(publisher._pubkey.verify(signed_prefix,
                                                         signature))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for hash_num, share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(hash_num, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, self._fn.get_encprivkey())
        d = nm.create_mutable_file(CONTENTS)
        d.addCallback(_generate)
        d.addCallback(_verify)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError
597
class PublishMixin:
    # Mixin for test cases that need an already-published mutable file.
    # Both publish_* methods set self._storage (the FakeStorage holding
    # the shares), self._nodemaker, self._storage_broker and self._fn
    # (the node that was published).

    def publish_one(self):
        """Publish a single 20000-byte file. Returns a Deferred that fires
        once self._fn (the writer node) and self._fn2 (a second node made
        from the same cap) are set."""
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
        d.addCallback(_created)
        return d

    def publish_multiple(self):
        """Publish five successive versions of one file, snapshotting the
        shares after each into self._copied_shares[0..4]. Versions 3 and 4
        ("3a"/"3b") are both written on top of version 2, so they are
        expected to share a sequence number. Returns a Deferred."""
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        # kept in step with publish_one, so helpers that read
        # self._storage_broker work after either publish method
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        """Snapshot every share currently in storage into
        self._copied_shares[index]. 'ignored' is the previous Deferred
        result, so this can be used directly as a callback."""
        shares = self._storage._peers
        # we need a deep copy: a fresh dict per peer is enough, because the
        # share data itself is stored as immutable strings
        snapshot = {}
        for peerid in shares:
            snapshot[peerid] = dict(shares[peerid])
        self._copied_shares[index] = snapshot

    def _set_versions(self, versionmap):
        """Roll individual shares back (or forward) to a snapshot taken by
        _copy_shares."""
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
667
668
class Servermap(unittest.TestCase, PublishMixin):
    """Servermap-update tests, each run against one freshly published file."""

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None):
        """Start a servermap update that fills a brand-new ServerMap.

        fn defaults to self._fn and sb to self._storage_broker. Returns
        whatever ServermapUpdater.update() returns.
        """
        node = fn
        if node is None:
            node = self._fn
        broker = sb
        if broker is None:
            broker = self._storage_broker
        updater = ServermapUpdater(node, broker, Monitor(),
                                   ServerMap(), mode)
        return updater.update()

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        """Start a servermap update for self._fn that re-uses 'oldmap'."""
        updater = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                                   oldmap, mode)
        return updater.update()

    def failUnlessOneRecoverable(self, sm, num_shares):
        """Assert that 'sm' knows exactly one version, and it is recoverable.

        shares_available() for that version must equal
        (num_shares, 3, 10). Returns sm so callbacks can be chained.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        # spot-check one (peer, share) pair taken from the sharemap
        some_shnum, holders = sm.make_sharemap().items()[0]
        some_peerid = list(holders)[0]
        self.failUnlessEqual(sm.version_on_peer(some_peerid, some_shnum), best)
        # a share number that nobody holds maps to no version
        self.failUnlessEqual(sm.version_on_peer(some_peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)

        # a fresh servermap for each mode, with the number of shares we
        # expect that mode to find
        fresh_expectations = [
            (MODE_CHECK, 10),
            (MODE_WRITE, 10),
            # this mode stops at k+epsilon, and epsilon=k, so 6 shares
            (MODE_READ, 6),
            # this mode stops at 'k' shares
            (MODE_ANYTHING, 3),
            ]
        for mode, expected in fresh_expectations:
            d.addCallback(lambda ign, mode=mode:
                          self.make_servermap(mode=mode))
            d.addCallback(self.failUnlessOneRecoverable, expected)

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(self.failUnlessOneRecoverable, 3)
        reuse_expectations = [
            (MODE_READ, 6),
            (MODE_WRITE, 10),
            (MODE_CHECK, 10),
            (MODE_ANYTHING, 10),
            ]
        for mode, expected in reuse_expectations:
            d.addCallback(self.update_servermap, mode=mode)
            d.addCallback(self.failUnlessOneRecoverable, expected)

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda ign: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(self.failUnlessOneRecoverable, 10)

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        big_contents = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda ign:
                      self._nodemaker.create_mutable_file(big_contents))
        def _map_through_sibling(big_node):
            # build a second node from the same cap and map through that one
            sibling = self._nodemaker.create_from_cap(big_node.get_uri())
            return self.make_servermap(MODE_WRITE, sibling)
        d.addCallback(_map_through_sibling)
        d.addCallback(self.failUnlessOneRecoverable, 10)
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)

        d.addCallback(lambda ign: self.make_servermap(mode=MODE_READ))
        d.addCallback(self.failUnlessOneRecoverable, 6)
        def _mark_and_update(sm):
            best = sm.best_recoverable_version()
            found = list(sm.make_versionmap()[best])
            self.failUnlessEqual(len(found), 6)
            self._corrupted = set()
            # mark every share numbered below 5 as corrupt, then update the
            # servermap. The map should not have the marked shares in it
            # any more, and new shares should be found to replace the
            # missing ones.
            for (shnum, peerid, timestamp) in found:
                if shnum >= 5:
                    continue
                self._corrupted.add( (peerid, shnum) )
                sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_mark_and_update)
        def _verify_updated_map(sm):
            # this should find all 5 shares that weren't marked bad
            best = sm.best_recoverable_version()
            found = list(sm.make_versionmap()[best])
            for (peerid, shnum) in self._corrupted:
                on_this_peer = sm.shares_on_peer(peerid)
                self.failIf(shnum in on_this_peer,
                            "%d was in %s" % (shnum, on_this_peer))
            self.failUnlessEqual(len(found), 5)
        d.addCallback(_verify_updated_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        """Assert that 'sm' lists no versions and no available shares."""
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        self.failUnlessEqual(sm.best_recoverable_version(), None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._storage._peers = {} # delete all shares
        d = defer.succeed(None)

        # every mode should come back with an empty map
        for mode in (MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(lambda ign, mode=mode:
                          self.make_servermap(mode=mode))
            d.addCallback(self.failUnlessNoneRecoverable)

        return d

    def failUnlessNotQuiteEnough(self, sm):
        """Assert that 'sm' holds one unrecoverable version with 2 shares.

        Returns sm so callbacks can be chained.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        self.failUnlessEqual(sm.best_recoverable_version(), None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        peers = self._storage._peers
        # empty out one peer after another until only two still hold shares
        still_populated = len(peers)
        for peerid in peers:
            peers[peerid] = {}
            still_populated -= 1
            if still_populated == 2:
                break
        # now there ought to be only two shares left
        survivors = [peerid for peerid in peers if peers[peerid]]
        assert len(survivors) == 2

        d = defer.succeed(None)

        d.addCallback(lambda ign: self.make_servermap(mode=MODE_CHECK))
        d.addCallback(self.failUnlessNotQuiteEnough)
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        for mode in (MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(lambda ign, mode=mode:
                          self.make_servermap(mode=mode))
            d.addCallback(self.failUnlessNotQuiteEnough)

        return d
848
849
850
class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    """Map and retrieve one published file, with and without corruption."""

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
        """Run a ServermapUpdater over self._fn.

        oldmap defaults to a fresh ServerMap and sb defaults to
        self._storage_broker. Returns whatever update() returns.
        """
        if oldmap is None:
            oldmap = ServerMap()
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(self._fn, sb, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        """Return a short 'seqnum-roothash' label, or None if verinfo is None.

        The label uses the first four characters of the base32 root hash.
        """
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        """Return a copy of verinfo_d keyed by abbrev_verinfo() labels."""
        output = {}
        for verinfo,value in verinfo_d.items():
            # share the formatting with abbrev_verinfo() instead of
            # repeating it here
            output[self.abbrev_verinfo(verinfo)] = value
        return output

    def dump_servermap(self, servermap):
        """Debugging aid: print a summary of 'servermap' to stdout."""
        # each print() takes a single pre-formatted string, which emits the
        # same text as the print statement and parses on Python 2 and 3
        print("SERVERMAP %s" % (servermap,))
        print("RECOVERABLE %s" % ([self.abbrev_verinfo(v)
                                   for v in servermap.recoverable_versions()],))
        print("BEST %s" %
              (self.abbrev_verinfo(servermap.best_recoverable_version()),))
        print("available %s" %
              (self.abbrev_verinfo_dict(servermap.shares_available()),))

    def do_download(self, servermap, version=None):
        """Retrieve 'version' (default: the best recoverable one) of self._fn.

        Returns whatever Retrieve.download() returns.
        """
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            # the servermap was built while the shares existed; delete them
            # all before attempting the download
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        sb2 = make_storagebroker(num_peers=0)
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap(sb=sb2)
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        sb2 = make_storagebroker(num_peers=0)
        self._fn._storage_broker = sb2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._storage_broker = self._storage_broker
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        """Corrupt the shares at 'offset', then map and try to download.

        offset is handed unchanged to corrupt(). With corrupt_early=True the
        corruption happens before the servermap is built, otherwise
        between mapupdate and retrieve.

        If no version is recoverable and should_succeed is false, 'substring'
        (when not empty) must appear in the servermap's problems. Otherwise
        the download must either return self.CONTENTS (should_succeed=True)
        or fail with NotEnoughSharesError mentioning 'substring'.
        failure_checker, if given, is added as a callback on the download
        result. The returned Deferred fires with the servermap.
        """
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnlessIn(substring, "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            # whatever happened above, hand the servermap to our caller
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an UnknownVersionError error
        # in unpack_share().
        d = self._test_corrupt_all(0, "UnknownVersionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            # go through the shared helper rather than building a second
            # Retrieve by hand
            return self.do_download(servermap, ver)
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        # the label must name this test, so a shouldFail() failure message
        # points at the right place
        d = self.shouldFail(UnrecoverableFileError, "test_download_fails",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d
1157
1158
class CheckerMixin:
    """Assertion helpers for results that offer is_healthy() and .problems.

    Relies on the failUnless/failIf/fail methods of the class it is mixed
    into.
    """

    def check_good(self, r, where):
        """Fail (with message 'where') unless r is healthy; return r."""
        healthy = r.is_healthy()
        self.failUnless(healthy, where)
        return r

    def check_bad(self, r, where):
        """Fail (with message 'where') if r is healthy; return r."""
        healthy = r.is_healthy()
        self.failIf(healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        """Require a problem of type expected_exception mentioning substring.

        Only the first entry of r.problems whose failure matches
        expected_exception is examined. Fails if no entry matches.
        """
        for (_peerid, _storage_index, _shnum, failure) in r.problems:
            if not failure.check(expected_exception):
                continue
            self.failUnless(substring in str(failure),
                            "%s: substring '%s' not in '%s'" %
                            (where, substring, str(failure)))
            return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))
1177
1178
class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    """Run check() and check(verify=True) on missing or corrupted shares."""

    def setUp(self):
        return self.publish_one()


    def test_check_good(self):
        d = self._fn.check(Monitor())
        d.addCallback(self.check_good, "test_check_good")
        return d

    def test_check_no_shares(self):
        for shares in self._storage._peers.values():
            shares.clear()
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_no_shares")
        return d

    def test_check_not_enough_shares(self):
        # keep only share 0, wherever it lives
        for shares in self._storage._peers.values():
            # iterate over a copy of the keys, since we delete entries from
            # the dict as we go
            for shnum in list(shares):
                if shnum > 0:
                    del shares[shnum]
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_not_enough_shares")
        return d

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_all_bad_sig")
        return d

    def test_check_all_bad_blocks(self):
        # NOTE(review): the [9] argument presumably limits the corruption to
        # share 9, despite the "all" in this test's name -- confirm against
        # corrupt()
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this.. it doesn't look at actual data
        d = self._fn.check(Monitor())
        d.addCallback(self.check_good, "test_check_all_bad_blocks")
        return d

    def test_verify_good(self):
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_good, "test_verify_good")
        return d

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_all_bad_sig")
        return d

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sig")
        return d

    def test_verify_one_bad_block(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_block")
        # check_bad returns the results, so they flow on to the next check
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "block hash tree failure",
                      "test_verify_one_bad_block")
        return d

    def test_verify_one_bad_sharehash(self):
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "corrupt hashes",
                      "test_verify_one_bad_sharehash")
        return d

    def test_verify_one_bad_encprivkey(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "invalid privkey",
                      "test_verify_one_bad_encprivkey")
        return d

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        readonly_fn = self._fn.get_readonly()
        # a read-only node has no way to validate the privkey
        d = readonly_fn.check(Monitor(), verify=True)
        d.addCallback(self.check_good,
                      "test_verify_one_bad_encprivkey_uncheckable")
        return d
1271
1272 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1273
1274     def get_shares(self, s):
1275         all_shares = {} # maps (peerid, shnum) to share data
1276         for peerid in s._peers:
1277             shares = s._peers[peerid]
1278             for shnum in shares:
1279                 data = shares[shnum]
1280                 all_shares[ (peerid, shnum) ] = data
1281         return all_shares
1282
1283     def copy_shares(self, ignored=None):
1284         self.old_shares.append(self.get_shares(self._storage))
1285
1286     def test_repair_nop(self):
1287         self.old_shares = []
1288         d = self.publish_one()
1289         d.addCallback(self.copy_shares)
1290         d.addCallback(lambda res: self._fn.check(Monitor()))
1291         d.addCallback(lambda check_results: self._fn.repair(check_results))
1292         def _check_results(rres):
1293             self.failUnless(IRepairResults.providedBy(rres))
1294             self.failUnless(rres.get_successful())
1295             # TODO: examine results
1296
1297             self.copy_shares()
1298
1299             initial_shares = self.old_shares[0]
1300             new_shares = self.old_shares[1]
1301             # TODO: this really shouldn't change anything. When we implement
1302             # a "minimal-bandwidth" repairer", change this test to assert:
1303             #self.failUnlessEqual(new_shares, initial_shares)
1304
1305             # all shares should be in the same place as before
1306             self.failUnlessEqual(set(initial_shares.keys()),
1307                                  set(new_shares.keys()))
1308             # but they should all be at a newer seqnum. The IV will be
1309             # different, so the roothash will be too.
1310             for key in initial_shares:
1311                 (version0,
1312                  seqnum0,
1313                  root_hash0,
1314                  IV0,
1315                  k0, N0, segsize0, datalen0,
1316                  o0) = unpack_header(initial_shares[key])
1317                 (version1,
1318                  seqnum1,
1319                  root_hash1,
1320                  IV1,
1321                  k1, N1, segsize1, datalen1,
1322                  o1) = unpack_header(new_shares[key])
1323                 self.failUnlessEqual(version0, version1)
1324                 self.failUnlessEqual(seqnum0+1, seqnum1)
1325                 self.failUnlessEqual(k0, k1)
1326                 self.failUnlessEqual(N0, N1)
1327                 self.failUnlessEqual(segsize0, segsize1)
1328                 self.failUnlessEqual(datalen0, datalen1)
1329         d.addCallback(_check_results)
1330         return d
1331
1332     def failIfSharesChanged(self, ignored=None):
1333         old_shares = self.old_shares[-2]
1334         current_shares = self.old_shares[-1]
1335         self.failUnlessEqual(old_shares, current_shares)
1336
1337     def test_unrepairable_0shares(self):
1338         d = self.publish_one()
1339         def _delete_all_shares(ign):
1340             shares = self._storage._peers
1341             for peerid in shares:
1342                 shares[peerid] = {}
1343         d.addCallback(_delete_all_shares)
1344         d.addCallback(lambda ign: self._fn.check(Monitor()))
1345         d.addCallback(lambda check_results: self._fn.repair(check_results))
1346         def _check(crr):
1347             self.failUnlessEqual(crr.get_successful(), False)
1348         d.addCallback(_check)
1349         return d
1350
1351     def test_unrepairable_1share(self):
1352         d = self.publish_one()
1353         def _delete_all_shares(ign):
1354             shares = self._storage._peers
1355             for peerid in shares:
1356                 for shnum in list(shares[peerid]):
1357                     if shnum > 0:
1358                         del shares[peerid][shnum]
1359         d.addCallback(_delete_all_shares)
1360         d.addCallback(lambda ign: self._fn.check(Monitor()))
1361         d.addCallback(lambda check_results: self._fn.repair(check_results))
1362         def _check(crr):
1363             self.failUnlessEqual(crr.get_successful(), False)
1364         d.addCallback(_check)
1365         return d
1366
1367     def test_merge(self):
1368         self.old_shares = []
1369         d = self.publish_multiple()
1370         # repair will refuse to merge multiple highest seqnums unless you
1371         # pass force=True
1372         d.addCallback(lambda res:
1373                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1374                                           1:4,3:4,5:4,7:4,9:4}))
1375         d.addCallback(self.copy_shares)
1376         d.addCallback(lambda res: self._fn.check(Monitor()))
1377         def _try_repair(check_results):
1378             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1379             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1380                                  self._fn.repair, check_results)
1381             d2.addCallback(self.copy_shares)
1382             d2.addCallback(self.failIfSharesChanged)
1383             d2.addCallback(lambda res: check_results)
1384             return d2
1385         d.addCallback(_try_repair)
1386         d.addCallback(lambda check_results:
1387                       self._fn.repair(check_results, force=True))
1388         # this should give us 10 shares of the highest roothash
1389         def _check_repair_results(rres):
1390             self.failUnless(rres.get_successful())
1391             pass # TODO
1392         d.addCallback(_check_repair_results)
1393         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1394         def _check_smap(smap):
1395             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1396             self.failIf(smap.unrecoverable_versions())
1397             # now, which should have won?
1398             roothash_s4a = self.get_roothash_for(3)
1399             roothash_s4b = self.get_roothash_for(4)
1400             if roothash_s4b > roothash_s4a:
1401                 expected_contents = self.CONTENTS[4]
1402             else:
1403                 expected_contents = self.CONTENTS[3]
1404             new_versionid = smap.best_recoverable_version()
1405             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1406             d2 = self._fn.download_version(smap, new_versionid)
1407             d2.addCallback(self.failUnlessEqual, expected_contents)
1408             return d2
1409         d.addCallback(_check_smap)
1410         return d
1411
1412     def test_non_merge(self):
1413         self.old_shares = []
1414         d = self.publish_multiple()
1415         # repair should not refuse a repair that doesn't need to merge. In
1416         # this case, we combine v2 with v3. The repair should ignore v2 and
1417         # copy v3 into a new v5.
1418         d.addCallback(lambda res:
1419                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1420                                           1:3,3:3,5:3,7:3,9:3}))
1421         d.addCallback(lambda res: self._fn.check(Monitor()))
1422         d.addCallback(lambda check_results: self._fn.repair(check_results))
1423         # this should give us 10 shares of v3
1424         def _check_repair_results(rres):
1425             self.failUnless(rres.get_successful())
1426             pass # TODO
1427         d.addCallback(_check_repair_results)
1428         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1429         def _check_smap(smap):
1430             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1431             self.failIf(smap.unrecoverable_versions())
1432             # now, which should have won?
1433             expected_contents = self.CONTENTS[3]
1434             new_versionid = smap.best_recoverable_version()
1435             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1436             d2 = self._fn.download_version(smap, new_versionid)
1437             d2.addCallback(self.failUnlessEqual, expected_contents)
1438             return d2
1439         d.addCallback(_check_smap)
1440         return d
1441
1442     def get_roothash_for(self, index):
1443         # return the roothash for the first share we see in the saved set
1444         shares = self._copied_shares[index]
1445         for peerid in shares:
1446             for shnum in shares[peerid]:
1447                 share = shares[peerid][shnum]
1448                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1449                           unpack_header(share)
1450                 return root_hash
1451
1452     def test_check_and_repair_readcap(self):
1453         # we can't currently repair from a mutable readcap: #625
1454         self.old_shares = []
1455         d = self.publish_one()
1456         d.addCallback(self.copy_shares)
1457         def _get_readcap(res):
1458             self._fn3 = self._fn.get_readonly()
1459             # also delete some shares
1460             for peerid,shares in self._storage._peers.items():
1461                 shares.pop(0, None)
1462         d.addCallback(_get_readcap)
1463         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1464         def _check_results(crr):
1465             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1466             # we should detect the unhealthy, but skip over mutable-readcap
1467             # repairs until #625 is fixed
1468             self.failIf(crr.get_pre_repair_results().is_healthy())
1469             self.failIf(crr.get_repair_attempted())
1470             self.failIf(crr.get_post_repair_results().is_healthy())
1471         d.addCallback(_check_results)
1472         return d
1473
class DevNullDictionary(dict):
    # A dict whose item assignment is a no-op: 'd[key] = value' is accepted
    # but nothing is stored.
    def __setitem__(self, key, value):
        pass
1477
class MultipleEncodings(unittest.TestCase):
    # Publishes the same mutable file with several different k-of-N
    # encodings, interleaves the resulting shares in the fake storage, and
    # checks that download_best_version() still returns usable contents.

    def setUp(self):
        # build a 20-peer fake grid and create the mutable file that the
        # test will re-encode
        self.CONTENTS = "New contents go here"
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage, num_peers=20)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict, using k-of-n encoding.
        # Returns a Deferred that fires with that dict. The fake storage is
        # emptied both before the publish and after the shares have been
        # collected.

        fn = self._fn
        # disable the nodecache, since for these tests we explicitly need
        # multiple nodes pointing at the same file
        self._nodemaker._node_cache = DevNullDictionary()
        fn2 = self._nodemaker.create_from_cap(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, self._storage_broker, None)
        d = p2.publish(data)
        def _published(res):
            # hand the freshly-written shares to the caller and leave the
            # storage empty for the next encoding
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        # run a ServermapUpdater against self._fn, starting from 'oldmap'
        # (or a fresh ServerMap), and return the Deferred from update()
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                               oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = self._nodemaker.create_from_cap(self._fn.get_uri())

        # now we upload a file through fn1, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum
            # in its original location, but using a share from set1, set2
            # or set3 according to the following sequence:
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # so that neither form can be recovered until fetch [f], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will be recoverable at fetch [g].
            #
            # NOTE(review): counting the table above, s1 first has k=3
            # shares at [g] and s2 first has k=4 shares at [h]; the fetch
            # letters in the paragraph above may be stale -- confirm.

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]

            sharemap = {}
            sb = self._storage_broker

            for peerid in sorted(sb.get_all_serverids()):
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    # fetch (or create) this peer's share dict. Creating a
                    # brand-new dict on every pass would throw away shares
                    # already placed on a peer that holds more than one
                    # shnum.
                    peers = self._storage._peers.setdefault(peerid, {})
                    in_1 = shnum in self._shares1[peerid]
                    in_2 = shnum in self._shares2.get(peerid, {})
                    in_3 = shnum in self._shares3.get(peerid, {})
                    if which == 1:
                        if in_1:
                            peers[shnum] = self._shares1[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 2:
                        if in_2:
                            peers[shnum] = self._shares2[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 3:
                        if in_3:
                            peers[shnum] = self._shares3[peerid][shnum]
                            sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
1618
1619
class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
    # Tests for download, checking and modification when the shares in the
    # grid come from more than one version of the same mutable file.

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(dict.fromkeys((0, 2, 4, 6, 8), 2))

        def _download_best(res):
            return self._fn.download_best_version()

        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = list(newer.items())[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())

        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())

        def _got_either_parallel_version(res):
            self.failUnless(res == self.CONTENTS[3] or
                            res == self.CONTENTS[4])

        d = self._fn.download_best_version()
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
        # and the checker should report problems
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(self.check_bad, "test_multiple_versions")

        # but if everything is at version 2, that's what we should download
        d.addCallback(lambda res:
                      self._set_versions(dict.fromkeys(range(10), 2)))
        d.addCallback(_download_best)
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # if exactly one share is at version 3, we should still get v2
        d.addCallback(lambda res: self._set_versions({0:3}))
        d.addCallback(_download_best)
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        d.addCallback(_check_smap)
        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        d.addCallback(lambda res:
                      self._set_versions({0:3, 2:3, 4:3, 6:3, 8:3,
                                          1:4, 3:4, 5:4, 7:4, 9:4}))
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        d.addCallback(_check_smap_mixed)
        d.addCallback(_download_best)
        d.addCallback(_got_either_parallel_version)
        return d

    def test_replace(self):
        # if we see a mix of versions in the grid, we should be able to
        # replace them all with a newer version

        # if exactly one share is at version 3, we should download (and
        # replace) v2, and the result should be v4. Note that the index we
        # give to _set_versions is different than the sequence number.
        target = dict.fromkeys(range(10), 2) # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        suffix = " modified"
        expected = self.CONTENTS[2] + suffix

        def _modify(oldversion, servermap, first_time):
            return oldversion + suffix

        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)

        d = self._fn.modify(_modify)
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, expected))
        # and the servermap should indicate that the outlier was replaced too
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        d.addCallback(_check_smap)
        return d
1699
1700
1701 class Utils(unittest.TestCase):
1702     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1703         # we compare this against sets of integers
1704         x = set(range(x_start, x_start+x_length))
1705         y = set(range(y_start, y_start+y_length))
1706         should_be_inside = x.issubset(y)
1707         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1708                                                          y_start, y_length),
1709                              str((x_start, x_length, y_start, y_length)))
1710
1711     def test_cache_inside(self):
1712         c = ResponseCache()
1713         x_start = 10
1714         x_length = 5
1715         for y_start in range(8, 17):
1716             for y_length in range(8):
1717                 self._do_inside(c, x_start, x_length, y_start, y_length)
1718
1719     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1720         # we compare this against sets of integers
1721         x = set(range(x_start, x_start+x_length))
1722         y = set(range(y_start, y_start+y_length))
1723         overlap = bool(x.intersection(y))
1724         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1725                                                       y_start, y_length),
1726                              str((x_start, x_length, y_start, y_length)))
1727
1728     def test_cache_overlap(self):
1729         c = ResponseCache()
1730         x_start = 10
1731         x_length = 5
1732         for y_start in range(8, 17):
1733             for y_length in range(8):
1734                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1735
1736     def test_cache(self):
1737         c = ResponseCache()
1738         # xdata = base62.b2a(os.urandom(100))[:100]
1739         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1740         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1741         nope = (None, None)
1742         c.add("v1", 1, 0, xdata, "time0")
1743         c.add("v1", 1, 2000, ydata, "time1")
1744         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1745         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1746         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1747         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1748         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1749         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1750         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1751         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1752         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1753         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1754         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1755         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1756         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1757         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1758         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1759         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1760         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1761         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1762
1763         # optional: join fragments
1764         c = ResponseCache()
1765         c.add("v1", 1, 0, xdata[:10], "time0")
1766         c.add("v1", 1, 10, xdata[10:20], "time1")
1767         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1768
class Exceptions(unittest.TestCase):
    def test_repr(self):
        # the repr() of each exception should mention its own class name
        def _repr_mentions(exc, classname):
            text = repr(exc)
            self.failUnless(classname in text, text)

        _repr_mentions(NeedMoreDataError(100, 50, 100), "NeedMoreDataError")
        _repr_mentions(UncoordinatedWriteError(), "UncoordinatedWriteError")
1775
class SameKeyGenerator:
    # Key-generator stand-in whose generate() always hands back the one
    # keypair it was constructed with, wrapped in an already-fired Deferred.
    def __init__(self, pubkey, privkey):
        self.pubkey, self.privkey = pubkey, privkey

    def generate(self, keysize=None):
        # 'keysize' is accepted but not used
        keypair = (self.pubkey, self.privkey)
        return defer.succeed(keypair)
1782
class FirstServerGetsKilled:
    # notify() hook: the wrapper seen on the very first call gets its
    # 'broken' attribute set to True; every call passes 'retval' through
    # unchanged.
    done = False

    def notify(self, retval, wrapper, methname):
        if self.done:
            return retval
        # first call only: mark this wrapper as broken
        wrapper.broken = True
        self.done = True
        return retval
1790
class FirstServerGetsDeleted:
    # notify() hook: the first call passes its result through and
    # remembers which wrapper it came from. Afterwards, calls on that same
    # wrapper get (True, {}) substituted for their real result, while calls
    # on any other wrapper pass through untouched.
    def __init__(self):
        self.done = False
        self.silenced = None

    def notify(self, retval, wrapper, methname):
        if self.done:
            if wrapper == self.silenced:
                assert methname == "slot_testv_and_readv_and_writev"
                return (True, {})
            return retval
        # this query will work, but later queries should think the share
        # has been deleted
        self.done = True
        self.silenced = wrapper
        return retval
1806
1807 class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
1808     def test_publish_surprise(self):
1809         self.basedir = "mutable/Problems/test_publish_surprise"
1810         self.set_up_grid()
1811         nm = self.g.clients[0].nodemaker
1812         d = nm.create_mutable_file("contents 1")
1813         def _created(n):
1814             d = defer.succeed(None)
1815             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1816             def _got_smap1(smap):
1817                 # stash the old state of the file
1818                 self.old_map = smap
1819             d.addCallback(_got_smap1)
1820             # then modify the file, leaving the old map untouched
1821             d.addCallback(lambda res: log.msg("starting winning write"))
1822             d.addCallback(lambda res: n.overwrite("contents 2"))
1823             # now attempt to modify the file with the old servermap. This
1824             # will look just like an uncoordinated write, in which every
1825             # single share got updated between our mapupdate and our publish
1826             d.addCallback(lambda res: log.msg("starting doomed write"))
1827             d.addCallback(lambda res:
1828                           self.shouldFail(UncoordinatedWriteError,
1829                                           "test_publish_surprise", None,
1830                                           n.upload,
1831                                           "contents 2a", self.old_map))
1832             return d
1833         d.addCallback(_created)
1834         return d
1835
1836     def test_retrieve_surprise(self):
1837         self.basedir = "mutable/Problems/test_retrieve_surprise"
1838         self.set_up_grid()
1839         nm = self.g.clients[0].nodemaker
1840         d = nm.create_mutable_file("contents 1")
1841         def _created(n):
1842             d = defer.succeed(None)
1843             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1844             def _got_smap1(smap):
1845                 # stash the old state of the file
1846                 self.old_map = smap
1847             d.addCallback(_got_smap1)
1848             # then modify the file, leaving the old map untouched
1849             d.addCallback(lambda res: log.msg("starting winning write"))
1850             d.addCallback(lambda res: n.overwrite("contents 2"))
1851             # now attempt to retrieve the old version with the old servermap.
1852             # This will look like someone has changed the file since we
1853             # updated the servermap.
1854             d.addCallback(lambda res: n._cache._clear())
1855             d.addCallback(lambda res: log.msg("starting doomed read"))
1856             d.addCallback(lambda res:
1857                           self.shouldFail(NotEnoughSharesError,
1858                                           "test_retrieve_surprise",
1859                                           "ran out of peers: have 0 shares (k=3)",
1860                                           n.download_version,
1861                                           self.old_map,
1862                                           self.old_map.best_recoverable_version(),
1863                                           ))
1864             return d
1865         d.addCallback(_created)
1866         return d
1867
1868     def test_unexpected_shares(self):
1869         # upload the file, take a servermap, shut down one of the servers,
1870         # upload it again (causing shares to appear on a new server), then
1871         # upload using the old servermap. The last upload should fail with an
1872         # UncoordinatedWriteError, because of the shares that didn't appear
1873         # in the servermap.
1874         self.basedir = "mutable/Problems/test_unexpected_shares"
1875         self.set_up_grid()
1876         nm = self.g.clients[0].nodemaker
1877         d = nm.create_mutable_file("contents 1")
1878         def _created(n):
1879             d = defer.succeed(None)
1880             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1881             def _got_smap1(smap):
1882                 # stash the old state of the file
1883                 self.old_map = smap
1884                 # now shut down one of the servers
1885                 peer0 = list(smap.make_sharemap()[0])[0]
1886                 self.g.remove_server(peer0)
1887                 # then modify the file, leaving the old map untouched
1888                 log.msg("starting winning write")
1889                 return n.overwrite("contents 2")
1890             d.addCallback(_got_smap1)
1891             # now attempt to modify the file with the old servermap. This
1892             # will look just like an uncoordinated write, in which every
1893             # single share got updated between our mapupdate and our publish
1894             d.addCallback(lambda res: log.msg("starting doomed write"))
1895             d.addCallback(lambda res:
1896                           self.shouldFail(UncoordinatedWriteError,
1897                                           "test_surprise", None,
1898                                           n.upload,
1899                                           "contents 2a", self.old_map))
1900             return d
1901         d.addCallback(_created)
1902         return d
1903
    def test_bad_server(self):
        # Break one server, then create the file: the initial publish should
        # complete with an alternate server. Breaking a second server should
        # not prevent an update from succeeding either.
        self.basedir = "mutable/Problems/test_bad_server"
        self.set_up_grid()
        nm = self.g.clients[0].nodemaker

        # to make sure that one of the initial peers is broken, we have to
        # get creative. We create an RSA key and compute its storage-index.
        # Then we make a KeyGenerator that always returns that one key, and
        # use it to create the mutable file. This will get easier when we can
        # use #467 static-server-selection to disable permutation and force
        # the choice of server for share[0].

        # 522 is presumably the requested key size -- confirm against the
        # key generator's generate() signature
        d = nm.key_generator.generate(522)
        def _got_key( (pubkey, privkey) ):
            # from now on the nodemaker will hand out this same keypair
            nm.key_generator = SameKeyGenerator(pubkey, privkey)
            pubkey_s = pubkey.serialize()
            privkey_s = privkey.serialize()
            # build the write-cap URI for this keypair so we can learn the
            # storage index before the file is created
            u = uri.WriteableSSKFileURI(ssk_writekey_hash(privkey_s),
                                        ssk_pubkey_fingerprint_hash(pubkey_s))
            self._storage_index = u.get_storage_index()
        d.addCallback(_got_key)
        def _break_peer0(res):
            si = self._storage_index
            peerlist = nm.storage_broker.get_servers_for_index(si)
            peerid0, connection0 = peerlist[0]
            peerid1, connection1 = peerlist[1]
            # break the first server in the list right away
            connection0.broken = True
            # remember the second one so it can be broken later
            self.connection1 = connection1
        d.addCallback(_break_peer0)
        # now "create" the file, using the pre-established key, and let the
        # initial publish finally happen
        d.addCallback(lambda res: nm.create_mutable_file("contents 1"))
        # that ought to work
        def _got_node(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            # now break the second peer
            def _break_peer1(res):
                self.connection1.broken = True
            d.addCallback(_break_peer1)
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # that ought to work too
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            def _explain_error(f):
                # debugging aid: print the failure (and, for
                # NotEnoughServersError, its first_error), then return it so
                # the test still fails
                print f
                if f.check(NotEnoughServersError):
                    print "first_error:", f.value.first_error
                return f
            d.addErrback(_explain_error)
            return d
        d.addCallback(_got_node)
        return d
1960
1961     def test_bad_server_overlap(self):
1962         # like test_bad_server, but with no extra unused servers to fall back
1963         # upon. This means that we must re-use a server which we've already
1964         # used. If we don't remember the fact that we sent them one share
1965         # already, we'll mistakenly think we're experiencing an
1966         # UncoordinatedWriteError.
1967
1968         # Break one server, then create the file: the initial publish should
1969         # complete with an alternate server. Breaking a second server should
1970         # not prevent an update from succeeding either.
1971         self.basedir = "mutable/Problems/test_bad_server_overlap"
1972         self.set_up_grid()
1973         nm = self.g.clients[0].nodemaker
1974         sb = nm.storage_broker
1975
1976         peerids = [serverid for (serverid,ss) in sb.get_all_servers()]
1977         self.g.break_server(peerids[0])
1978
1979         d = nm.create_mutable_file("contents 1")
1980         def _created(n):
1981             d = n.download_best_version()
1982             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1983             # now break one of the remaining servers
1984             def _break_second_server(res):
1985                 self.g.break_server(peerids[1])
1986             d.addCallback(_break_second_server)
1987             d.addCallback(lambda res: n.overwrite("contents 2"))
1988             # that ought to work too
1989             d.addCallback(lambda res: n.download_best_version())
1990             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1991             return d
1992         d.addCallback(_created)
1993         return d
1994
1995     def test_publish_all_servers_bad(self):
1996         # Break all servers: the publish should fail
1997         self.basedir = "mutable/Problems/test_publish_all_servers_bad"
1998         self.set_up_grid()
1999         nm = self.g.clients[0].nodemaker
2000         for (serverid,ss) in nm.storage_broker.get_all_servers():
2001             ss.broken = True
2002
2003         d = self.shouldFail(NotEnoughServersError,
2004                             "test_publish_all_servers_bad",
2005                             "Ran out of non-bad servers",
2006                             nm.create_mutable_file, "contents")
2007         return d
2008
2009     def test_publish_no_servers(self):
2010         # no servers at all: the publish should fail
2011         self.basedir = "mutable/Problems/test_publish_no_servers"
2012         self.set_up_grid(num_servers=0)
2013         nm = self.g.clients[0].nodemaker
2014
2015         d = self.shouldFail(NotEnoughServersError,
2016                             "test_publish_no_servers",
2017                             "Ran out of non-bad servers",
2018                             nm.create_mutable_file, "contents")
2019         return d
2020     test_publish_no_servers.timeout = 30
2021
2022
2023     def test_privkey_query_error(self):
2024         # when a servermap is updated with MODE_WRITE, it tries to get the
2025         # privkey. Something might go wrong during this query attempt.
2026         # Exercise the code in _privkey_query_failed which tries to handle
2027         # such an error.
2028         self.basedir = "mutable/Problems/test_privkey_query_error"
2029         self.set_up_grid(num_servers=20)
2030         nm = self.g.clients[0].nodemaker
2031         nm._node_cache = DevNullDictionary() # disable the nodecache
2032
2033         # we need some contents that are large enough to push the privkey out
2034         # of the early part of the file
2035         LARGE = "These are Larger contents" * 2000 # about 50KB
2036         d = nm.create_mutable_file(LARGE)
2037         def _created(n):
2038             self.uri = n.get_uri()
2039             self.n2 = nm.create_from_cap(self.uri)
2040
2041             # When a mapupdate is performed on a node that doesn't yet know
2042             # the privkey, a short read is sent to a batch of servers, to get
2043             # the verinfo and (hopefully, if the file is short enough) the
2044             # encprivkey. Our file is too large to let this first read
2045             # contain the encprivkey. Each non-encprivkey-bearing response
2046             # that arrives (until the node gets the encprivkey) will trigger
2047             # a second read to specifically read the encprivkey.
2048             #
2049             # So, to exercise this case:
2050             #  1. notice which server gets a read() call first
2051             #  2. tell that server to start throwing errors
2052             killer = FirstServerGetsKilled()
2053             for (serverid,ss) in nm.storage_broker.get_all_servers():
2054                 ss.post_call_notifier = killer.notify
2055         d.addCallback(_created)
2056
2057         # now we update a servermap from a new node (which doesn't have the
2058         # privkey yet, forcing it to use a separate privkey query). Note that
2059         # the map-update will succeed, since we'll just get a copy from one
2060         # of the other shares.
2061         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2062
2063         return d
2064
2065     def test_privkey_query_missing(self):
2066         # like test_privkey_query_error, but the shares are deleted by the
2067         # second query, instead of raising an exception.
2068         self.basedir = "mutable/Problems/test_privkey_query_missing"
2069         self.set_up_grid(num_servers=20)
2070         nm = self.g.clients[0].nodemaker
2071         LARGE = "These are Larger contents" * 2000 # about 50KB
2072         nm._node_cache = DevNullDictionary() # disable the nodecache
2073
2074         d = nm.create_mutable_file(LARGE)
2075         def _created(n):
2076             self.uri = n.get_uri()
2077             self.n2 = nm.create_from_cap(self.uri)
2078             deleter = FirstServerGetsDeleted()
2079             for (serverid,ss) in nm.storage_broker.get_all_servers():
2080                 ss.post_call_notifier = deleter.notify
2081         d.addCallback(_created)
2082         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2083         return d