]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_mutable.py
mutable repair: return successful=False when numshares<k (thus repair fails),
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_mutable.py
1
2 import struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from allmydata import uri, client
7 from allmydata.nodemaker import NodeMaker
8 from allmydata.util import base32
9 from allmydata.util.idlib import shortnodeid_b2a
10 from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
11      ssk_pubkey_fingerprint_hash
12 from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
13      NotEnoughSharesError
14 from allmydata.monitor import Monitor
15 from allmydata.test.common import ShouldFailMixin
16 from allmydata.test.no_network import GridTestMixin
17 from foolscap.api import eventually, fireEventually
18 from foolscap.logging import log
19 from allmydata.storage_client import StorageFarmBroker
20
21 from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
22 from allmydata.mutable.common import ResponseCache, \
23      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
24      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
25      NotEnoughServersError, CorruptShareError
26 from allmydata.mutable.retrieve import Retrieve
27 from allmydata.mutable.publish import Publish
28 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
29 from allmydata.mutable.layout import unpack_header, unpack_share
30 from allmydata.mutable.repairer import MustForceRepairError
31
32 import common_util as testutil
33
34 # this "FakeStorage" exists to put the share data in RAM and avoid using real
35 # network connections, both to speed up the tests and to reduce the amount of
36 # non-mutable.py code being exercised.
37
class FakeStorage:
    # This class replaces the collection of storage servers, keeping all
    # share data in RAM so the tests can examine and manipulate the
    # published shares. It also lets us control the order in which read
    # queries are answered, to exercise more of the error-handling code in
    # Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        # peerid -> {shnum: share data (a string)}
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use (i.e. not None), then we will defer queries
        # instead of answering them right away, accumulating the Deferreds
        # in a dict. We don't know exactly how many queries we'll get, so
        # exactly one second after the first query arrives, we will release
        # them all (in order).
        self._sequence = None
        # peerid -> list of (Deferred, shares) pairs that are waiting for
        # _fire_readers. A list is kept per peer so that several reads from
        # the same peer inside one window are all answered, instead of the
        # later read silently replacing (and thereby orphaning) the earlier
        # Deferred.
        self._pending = {}
        self._pending_timer = None

    def read(self, peerid, storage_index):
        """Return a Deferred that fires with the {shnum: data} dict held for
        `peerid` (an empty dict if that peer holds nothing).

        `storage_index` is ignored. When `_sequence` is None the Deferred
        has already fired; otherwise the answer is held back until
        `_fire_readers` runs, one second after the first held-back query.
        """
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            # first query of this batch: start the release timer
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending.setdefault(peerid, []).append((d, shares))
        return d

    def _fire_readers(self):
        """Release every held-back read: peers named in `_sequence` first, in
        that order, then everything else."""
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            for (d, shares) in pending.pop(peerid, []):
                eventually(d.callback, shares)
        # whatever is left belongs to peers that _sequence did not mention
        for waiters in pending.values():
            for (d, shares) in waiters:
                eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        """Overwrite `data` into share `shnum` of `peerid` starting at
        `offset`, creating the peer and/or the share if necessary.

        `storage_index` is ignored.
        """
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
91
92
class FakeStorageServer:
    # Stands in for the remote reference to one storage server. Calls are
    # dispatched by method name to the methods below, and the share data
    # itself lives in the shared FakeStorage passed to the constructor.

    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0

    def callRemote(self, methname, *args, **kwargs):
        # The method is looked up and invoked only once the fireEventually()
        # Deferred fires, not at the time callRemote itself is called.
        def _invoke(ignored):
            target = getattr(self, methname)
            return target(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_invoke)
        return d

    def callRemoteOnly(self, methname, *args, **kwargs):
        # Fire-and-forget flavour of callRemote: both the result and any
        # failure are discarded, and nothing is returned to the caller.
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        # corruption advisories are accepted and ignored
        pass

    def slot_readv(self, storage_index, shnums, readv):
        # Returns a Deferred that fires with {shnum: [string, ...]}, holding
        # one string per (offset, length) pair in readv. A false-y shnums
        # (e.g. an empty list) selects every share this peer holds.
        def _build_response(shares):
            response = {}
            for shnum, sharedata in shares.items():
                wanted = (not shnums) or (shnum in shnums)
                if not wanted:
                    continue
                pieces = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    pieces.append(sharedata[offset:offset+length])
                response[shnum] = pieces
            return response
        d = self.storage.read(self.peerid, storage_index)
        d.addCallback(_build_response)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: the test vectors are never evaluated against the
        # stored data, we just parrot their specimens back and apply every
        # write.
        parroted = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            specimens = []
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
                specimens.append(specimen)
            parroted[shnum] = specimens
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        return fireEventually((True, parroted))
146
147
def flip_bit(original, byte_offset):
    """Return a copy of `original` in which the lowest bit of the byte at
    `byte_offset` has been inverted.

    `original` is a string and is not modified. A negative `byte_offset`
    counts from the end of the string, as with normal indexing. Raises
    IndexError if `byte_offset` lies outside the string.
    """
    if byte_offset < 0:
        # Normalize before slicing. With a raw negative offset the tail
        # slice original[byte_offset+1:] can wrap around (-1 + 1 == 0 would
        # append the whole string again) and produce a longer, wrong result.
        byte_offset += len(original)
        if byte_offset < 0:
            raise IndexError("string index out of range")
    flipped = chr(ord(original[byte_offset]) ^ 0x01)
    return original[:byte_offset] + flipped + original[byte_offset+1:]
152
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    """Flip one bit in shares held by the FakeStorage `s`, then return `res`.

    Returning `res` unchanged lets this be used directly as a Deferred
    callback.

    `offset` selects the byte to damage. It is either a single value or a
    (value, extra) tuple, where the value is the string "pubkey", a key of
    the share's own offset table, or a plain number. `extra` and
    `offset_offset` are added to the resolved position.

    If `shnums_to_corrupt` is None, all shares are corrupted. Otherwise it
    is a list of the shnums to corrupt.
    """
    for shares in s._peers.values():
        for shnum in shares:
            selected = (shnums_to_corrupt is None
                        or shnum in shnums_to_corrupt)
            if not selected:
                continue
            data = shares[shnum]
            # only the offset table (the last field) is needed here
            (_version, _seqnum, _root_hash, _IV,
             _k, _N, _segsize, _datalen,
             offset_table) = unpack_header(data)
            if isinstance(offset, tuple):
                base, extra = offset
            else:
                base, extra = offset, 0
            # "pubkey" is resolved to a fixed position rather than through
            # the offset table
            if base == "pubkey":
                position = 107
            elif base in offset_table:
                position = offset_table[base]
            else:
                position = base
            position = int(position) + extra + offset_offset
            assert isinstance(position, int), offset
            shares[shnum] = flip_bit(data, position)
    return res
184
def make_storagebroker(s=None, num_peers=10):
    """Return a StorageFarmBroker populated with `num_peers` fake servers.

    Every FakeStorageServer shares the one FakeStorage `s`; a fresh
    FakeStorage is created when `s` is not given. Peer ids are the first 20
    bytes of tagged_hash("peerid", <index>), so they are the same on every
    run.
    """
    storage = s or FakeStorage()
    peerids = []
    for i in range(num_peers):
        peerids.append(tagged_hash("peerid", "%d" % i)[:20])
    storage_broker = StorageFarmBroker(None, True)
    for peerid in peerids:
        storage_broker.test_add_server(peerid,
                                       FakeStorageServer(peerid, storage))
    return storage_broker
195
def make_nodemaker(s=None, num_peers=10):
    """Return a NodeMaker wired to a fake storage grid.

    `s` and `num_peers` are handed to make_storagebroker(). Files made by
    the returned NodeMaker default to 3-of-10 encoding.
    """
    default_encoding = {"k": 3, "n": 10}
    storage_broker = make_storagebroker(s, num_peers)
    secret_holder = client.SecretHolder("lease secret", "convergence secret")
    keygen = client.KeyGenerator()
    # presumably a deliberately small key size, to keep key generation
    # cheap in tests -- TODO confirm
    keygen.set_default_keysize(522)
    return NodeMaker(storage_broker, secret_holder, None,
                     None, None, None,
                     default_encoding, keygen)
205
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # Tests of MutableFileNode creation, overwrite/upload, download and
    # modify(), all run against the in-RAM FakeStorage.

    # this used to be in Publish, but we removed the limit. Some of
    # these tests test whether the new code correctly allows files
    # larger than the limit.
    OLD_MAX_SEGMENT_SIZE = 3500000

    def setUp(self):
        storage = FakeStorage()
        self._storage = storage
        self.nodemaker = make_nodemaker(storage)

    def _then_expect_contents(self, d, node, expected):
        # Queue two steps onto d: download the best version of node, then
        # require that what came back equals expected.
        d.addCallback(lambda ign: node.download_best_version())
        d.addCallback(lambda got: self.failUnlessEqual(got, expected))

    def _then_expect_state(self, d, node, expected, seqnum, which):
        # Like _then_expect_contents, followed by a check that the best
        # recoverable version carries sequence number seqnum. 'which' labels
        # the step in the failure message.
        self._then_expect_contents(d, node, expected)
        d.addCallback(lambda ign:
                      self.failUnlessCurrentSeqnumIs(node, seqnum, which))

    def test_create(self):
        # an empty mutable file is a MutableFileNode, and the first server
        # (in sorted serverid order) ends up holding exactly one share
        def _check_node(node):
            self.failUnless(isinstance(node, MutableFileNode))
            self.failUnlessEqual(node.get_storage_index(),
                                 node._storage_index)
            serverids = self.nodemaker.storage_broker.get_all_serverids()
            first_peer = sorted(serverids)[0]
            held = self._storage._peers[first_peer].keys()
            self.failUnlessEqual(len(held), 1)
        d = self.nodemaker.create_mutable_file()
        d.addCallback(_check_node)
        return d

    def test_serialize(self):
        # _do_serialized must pass positional and keyword arguments through,
        # hand back the callable's return value, and propagate exceptions
        node = MutableFileNode(None, None, {"k": 3, "n": 10}, None)
        calls = []
        def _record_call(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        def _check_result(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        def _raiser():
            raise ValueError("heya")

        d = node._do_serialized(_record_call, 4, foo=5)
        d.addCallback(_check_result)
        d.addCallback(lambda ign:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      node._do_serialized, _raiser))
        return d

    def test_upload_and_download(self):
        def _exercise(n):
            d = defer.succeed(None)
            d.addCallback(lambda ign: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda ign: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            self._then_expect_contents(d, n, "contents 1")
            d.addCallback(lambda ign: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda ign: n.overwrite("contents 2"))
            self._then_expect_contents(d, n, "contents 2")
            d.addCallback(lambda ign: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            self._then_expect_contents(d, n, "contents 3")
            d.addCallback(lambda ign: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda got: self.failUnlessEqual(got, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            large = "large size file" * 1000
            d.addCallback(lambda ign: n.overwrite(large))
            self._then_expect_contents(d, n, large)
            return d
        d = self.nodemaker.create_mutable_file()
        d.addCallback(_exercise)
        return d

    def test_create_with_initial_contents(self):
        def _roundtrip(n):
            d = n.download_best_version()
            d.addCallback(lambda got: self.failUnlessEqual(got, "contents 1"))
            d.addCallback(lambda ign: n.overwrite("contents 2"))
            self._then_expect_contents(d, n, "contents 2")
            return d
        d = self.nodemaker.create_mutable_file("contents 1")
        d.addCallback(_roundtrip)
        return d

    def test_create_with_initial_contents_function(self):
        # the initial contents may be given as a callable, which is handed
        # the new node and returns the data to publish
        data = "initial contents"
        def _make_contents(n):
            self.failUnless(isinstance(n, MutableFileNode))
            key = n.get_writekey()
            self.failUnless(isinstance(key, str), key)
            self.failUnlessEqual(len(key), 16) # AES key size
            return data
        d = self.nodemaker.create_mutable_file(_make_contents)
        d.addCallback(lambda n: n.download_best_version())
        d.addCallback(lambda data2: self.failUnlessEqual(data2, data))
        return d

    def test_create_with_too_large_contents(self):
        # one byte more than the old limit must be accepted both as initial
        # contents and as an overwrite
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.nodemaker.create_mutable_file(BIG)
        d.addCallback(lambda n: n.overwrite(BIG))
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        # Returns a Deferred that fails unless the first element of n's best
        # recoverable version equals expected_seqnum. 'which' is used as the
        # failure message.
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda smap: smap.best_recoverable_version())
        d.addCallback(lambda best:
                      self.failUnlessEqual(best[0], expected_seqnum, which))
        return d

    def test_modify(self):
        calls = []
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents
        def _reset_ucw_error_modifier(res):
            calls[:] = []
            return res

        def _exercise(n):
            d = n.modify(_modifier)
            self._then_expect_state(d, n, "line1line2", 2, "m")

            d.addCallback(lambda ign: n.modify(_non_modifier))
            self._then_expect_state(d, n, "line1line2", 2, "non")

            d.addCallback(lambda ign: n.modify(_none_modifier))
            self._then_expect_state(d, n, "line1line2", 2, "none")

            d.addCallback(lambda ign:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            self._then_expect_state(d, n, "line1line2", 2, "err")

            # the same state is checked a second time here, with no
            # modify() call in between
            self._then_expect_state(d, n, "line1line2", 2, "big")

            d.addCallback(lambda ign: n.modify(_ucw_error_modifier))
            d.addCallback(lambda ign: self.failUnlessEqual(len(calls), 2))
            self._then_expect_state(d, n, "line1line2line3", 3, "ucw")

            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda ign: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda ign: self.failUnlessEqual(len(calls), 2))
            self._then_expect_state(d, n, "line1line2line3", 4, "ucw")

            # finally, a modifier that returns more than the old limit
            d.addCallback(lambda ign: n.modify(_toobig_modifier))
            return d
        d = self.nodemaker.create_mutable_file("line1")
        d.addCallback(_exercise)
        return d

    def test_modify_backoffer(self):
        calls = []
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            # hand the failure straight back
            return f
        def _backoff_pauser(node, f):
            # fire with None half a second later
            pause = defer.Deferred()
            reactor.callLater(0.5, pause.callback, None)
            return pause
        def _reset_ucw_error_modifier(res):
            calls[:] = []
            return res

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        def _exercise(n):
            d = n.modify(_modifier)
            self._then_expect_state(d, n, "line1line2", 2, "m")

            d.addCallback(lambda ign:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            self._then_expect_state(d, n, "line1line2", 2, "stop")

            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda ign: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            self._then_expect_state(d, n, "line1line2line3", 3, "pause")

            d.addCallback(lambda ign:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            self._then_expect_state(d, n, "line1line2line3", 3, "giveup")

            return d
        d = self.nodemaker.create_mutable_file("line1")
        d.addCallback(_exercise)
        return d

    def test_upload_and_download_full_size_keys(self):
        # same round trips as test_upload_and_download, but with a
        # KeyGenerator left at its default key size
        self.nodemaker.key_generator = client.KeyGenerator()
        def _exercise(n):
            d = defer.succeed(None)
            d.addCallback(lambda ign: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda ign: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            self._then_expect_contents(d, n, "contents 1")
            d.addCallback(lambda ign: n.overwrite("contents 2"))
            self._then_expect_contents(d, n, "contents 2")
            d.addCallback(lambda ign: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            self._then_expect_contents(d, n, "contents 3")
            d.addCallback(lambda ign: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda got: self.failUnlessEqual(got, "contents 3"))
            return d
        d = self.nodemaker.create_mutable_file()
        d.addCallback(_exercise)
        return d
507
508
class MakeShares(unittest.TestCase):
    # Drives the share-construction steps of Publish directly
    # (_encrypt_and_encode and _generate_shares) and checks what they
    # produce.

    def test_encrypt(self):
        CONTENTS = "some initial contents"
        nm = make_nodemaker()
        def _encode(fn):
            pub = Publish(fn, nm.storage_broker, None)
            pub.salt = "SALT" * 4
            pub.readkey = "\x00" * 16
            pub.newdata = CONTENTS
            pub.required_shares = 3
            pub.total_shares = 10
            pub.setup_encoding_parameters()
            return pub._encrypt_and_encode()
        def _check(shares_and_shareids):
            shares, share_ids = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for share in shares:
                self.failUnless(isinstance(share, str))
                # CONTENTS is 21 bytes long; presumably 7 is that length
                # divided among k=3 shares
                self.failUnlessEqual(len(share), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d = nm.create_mutable_file(CONTENTS)
        d.addCallback(_encode)
        d.addCallback(_check)
        return d

    def test_generate(self):
        CONTENTS = "some initial contents"
        nm = make_nodemaker()
        def _generate(fn):
            self._fn = fn
            pub = Publish(fn, nm.storage_broker, None)
            self._p = pub
            pub.newdata = CONTENTS
            pub.required_shares = 3
            pub.total_shares = 10
            pub.setup_encoding_parameters()
            pub._new_seqnum = 3
            pub.salt = "SALT" * 4
            # make some fake shares
            fake_shares = ["%07d" % i for i in range(10)]
            shares_and_ids = ( fake_shares, range(10) )
            pub._privkey = fn.get_privkey()
            pub._encprivkey = fn.get_encprivkey()
            pub._pubkey = fn.get_pubkey()
            return pub._generate_shares(shares_and_ids)
        def _inspect(res):
            pub = self._p
            final_shares = pub.shares
            root_hash = pub.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = unpack_share(sh)
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, pub._pubkey.serialize())
                # rebuild the signed header fields and check the signature
                # against them
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, pub._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(pub._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, self._fn.get_encprivkey())
        d = nm.create_mutable_file(CONTENTS)
        d.addCallback(_generate)
        d.addCallback(_inspect)
        return d
595
# TODO: when we publish to 20 peers, we should get one share per peer on 10
# of them. When we publish to 3 peers, we should get either 3 or 4 shares per
# peer. When we publish to zero peers, we should get a NotEnoughSharesError.
599
class PublishMixin:
    # Helpers for tests that need already-published shares to work on. Both
    # publish_* methods build a fresh FakeStorage and NodeMaker (with
    # make_nodemaker's default number of peers) and leave them on self as
    # _storage, _nodemaker and _storage_broker.

    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later. Returns a Deferred; when it fires, self._fn is the new node
        # and self._fn2 is a second node built from the same cap.
        self.CONTENTS = "New contents go here" * 1000
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
        d.addCallback(_created)
        return d

    def publish_multiple(self):
        # publish five versions of one file, saving a copy of the shares
        # after each publish in self._copied_shares[0..4]. Returns a
        # Deferred; when it fires, self._fn is the node.
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        # recorded here as well, so that code which reads
        # self._storage_broker works after either publish helper
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict.fromkeys(range(10), 2)
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        # Save a snapshot of every share currently in storage under
        # self._copied_shares[index]. 'ignored' is there so this can be used
        # directly as a Deferred callback.
        #
        # we need a deep copy: both dict levels are duplicated, so later
        # writes to the storage do not show through in the snapshot
        snapshot = {}
        for peerid, held in self._storage._peers.items():
            snapshot[peerid] = dict(held)
        self._copied_shares[index] = snapshot

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        snapshots = self._copied_shares
        for peerid, held in self._storage._peers.items():
            for shnum in held:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    held[shnum] = snapshots[index][peerid][shnum]
671
672
class Servermap(unittest.TestCase, PublishMixin):
    # Exercises ServermapUpdater against the single file published by
    # setUp(): each test builds (or refreshes) a ServerMap in one or more
    # modes and then checks which versions and shares the map reports.

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None):
        # Start a ServermapUpdater with a brand-new ServerMap and return the
        # result of its update() call (the tests below use it as a Deferred
        # that fires with the servermap). 'fn' defaults to self._fn and 'sb'
        # to self._storage_broker.
        filenode, broker = fn, sb
        if filenode is None:
            filenode = self._fn
        if broker is None:
            broker = self._storage_broker
        updater = ServermapUpdater(filenode, broker, Monitor(),
                                   ServerMap(), mode)
        return updater.update()

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        # Like make_servermap(), but hands the existing 'oldmap' to the
        # updater instead of a fresh ServerMap, always for self._fn.
        updater = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                                   oldmap, mode)
        return updater.update()

    def failUnlessOneRecoverable(self, sm, num_shares):
        # Assert that 'sm' knows exactly one version, that it is
        # recoverable, and that shares_available() reports
        # (num_shares, 3, 10) for it. Returns 'sm' so the method can sit in
        # the middle of a callback chain.
        eq = self.failUnlessEqual
        eq(len(sm.recoverable_versions()), 1)
        eq(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        eq(sm.recoverable_versions(), set([best]))
        eq(len(sm.shares_available()), 1)
        eq(sm.shares_available()[best], (num_shares, 3, 10))
        # spot-check version_on_peer() with whichever sharemap entry comes
        # first, plus a share number (666) that is not expected to exist
        first_entry = sm.make_sharemap().items()[0]
        (shnum, peerids) = first_entry
        peerid = list(peerids)[0]
        eq(sm.version_on_peer(peerid, shnum), best)
        eq(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        # small factories, so each callback captures its own mode/count
        def _fresh_map(mode):
            return lambda ign: self.make_servermap(mode=mode)
        def _refreshed_map(mode):
            return lambda sm: self.update_servermap(sm, mode=mode)
        def _expect(count):
            return lambda sm: self.failUnlessOneRecoverable(sm, count)

        d = defer.succeed(None)

        fresh_cases = [
            (MODE_CHECK, 10),
            (MODE_WRITE, 10),
            # this mode stops at k+epsilon, and epsilon=k, so 6 shares
            (MODE_READ, 6),
            # this mode stops at 'k' shares
            (MODE_ANYTHING, 3),
            ]
        for (mode, count) in fresh_cases:
            d.addCallback(_fresh_map(mode))
            d.addCallback(_expect(count))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(_expect(3))
        reuse_cases = [
            (MODE_READ, 6),
            (MODE_WRITE, 10),
            (MODE_CHECK, 10),
            (MODE_ANYTHING, 10),
            ]
        for (mode, count) in reuse_cases:
            d.addCallback(_refreshed_map(mode))
            d.addCallback(_expect(count))

        return d

    def test_fetch_privkey(self):
        LARGE = "These are Larger contents" * 200 # about 5KB

        def _map_sibling(ign):
            # use the sibling filenode (which hasn't been used yet), and
            # make sure it can fetch the privkey. The file is small, so the
            # privkey will be fetched on the first (query) pass.
            return self.make_servermap(MODE_WRITE, self._fn2)
        def _all_ten(sm):
            return self.failUnlessOneRecoverable(sm, 10)
        def _create_large(ign):
            # create a new file, which is large enough to knock the privkey
            # out of the early part of the file
            return self._nodemaker.create_mutable_file(LARGE)
        def _map_large(large_fn):
            large_fn2 = self._nodemaker.create_from_cap(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)

        d = defer.succeed(None)
        d.addCallback(_map_sibling)
        d.addCallback(_all_ten)
        d.addCallback(_create_large)
        d.addCallback(_map_large)
        d.addCallback(_all_ten)
        return d

    def test_mark_bad(self):
        def _mark_some_bad(sm):
            best = sm.best_recoverable_version()
            versionmap = sm.make_versionmap()
            found = list(versionmap[best])
            self.failUnlessEqual(len(found), 6)
            self._corrupted = set()
            # mark every found share numbered below 5 as corrupt, then
            # update the servermap. The map should not have the marked
            # shares in it any more, and new shares should be found to
            # replace the missing ones.
            for (shnum, peerid, timestamp) in found:
                if shnum >= 5:
                    continue
                self._corrupted.add( (peerid, shnum) )
                sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            best = sm.best_recoverable_version()
            versionmap = sm.make_versionmap()
            remaining = list(versionmap[best])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(remaining), 5)

        d = defer.succeed(None)
        d.addCallback(lambda ign: self.make_servermap(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(_mark_some_bad)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        # Assert that 'sm' saw no versions at all, recoverable or not.
        eq = self.failUnlessEqual
        eq(len(sm.recoverable_versions()), 0)
        eq(len(sm.unrecoverable_versions()), 0)
        eq(sm.best_recoverable_version(), None)
        eq(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._storage._peers = {} # delete all shares

        def _fresh_map(mode):
            return lambda ign: self.make_servermap(mode=mode)

        d = defer.succeed(None)
        for mode in (MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(_fresh_map(mode))
            d.addCallback(self.failUnlessNoneRecoverable)
        return d

    def failUnlessNotQuiteEnough(self, sm):
        # Assert that 'sm' saw exactly one version, that it is not
        # recoverable, and that shares_available() reports (2,3,10) for it.
        # Returns 'sm' for use in a callback chain.
        eq = self.failUnlessEqual
        eq(len(sm.recoverable_versions()), 0)
        eq(len(sm.unrecoverable_versions()), 1)
        eq(sm.best_recoverable_version(), None)
        eq(len(sm.shares_available()), 1)
        eq(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        peers = self._storage._peers
        # empty out one peer after another until only two are untouched
        untouched = len(peers)
        for peerid in peers:
            peers[peerid] = {}
            untouched -= 1
            if untouched == 2:
                break
        # now there ought to be only two shares left
        holders = [peerid for peerid in peers if peers[peerid]]
        assert len(holders) == 2

        def _fresh_map(mode):
            return lambda ign: self.make_servermap(mode=mode)
        def _two_in_sharemap(sm):
            return self.failUnlessEqual(len(sm.make_sharemap()), 2)

        d = defer.succeed(None)

        d.addCallback(_fresh_map(MODE_CHECK))
        d.addCallback(self.failUnlessNotQuiteEnough)
        d.addCallback(_two_in_sharemap)
        for mode in (MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(_fresh_map(mode))
            d.addCallback(self.failUnlessNotQuiteEnough)

        return d
853
854
855
856 class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
857     def setUp(self):
858         return self.publish_one()
859
860     def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
861         if oldmap is None:
862             oldmap = ServerMap()
863         if sb is None:
864             sb = self._storage_broker
865         smu = ServermapUpdater(self._fn, sb, Monitor(), oldmap, mode)
866         d = smu.update()
867         return d
868
869     def abbrev_verinfo(self, verinfo):
870         if verinfo is None:
871             return None
872         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
873          offsets_tuple) = verinfo
874         return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])
875
876     def abbrev_verinfo_dict(self, verinfo_d):
877         output = {}
878         for verinfo,value in verinfo_d.items():
879             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
880              offsets_tuple) = verinfo
881             output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
882         return output
883
884     def dump_servermap(self, servermap):
885         print "SERVERMAP", servermap
886         print "RECOVERABLE", [self.abbrev_verinfo(v)
887                               for v in servermap.recoverable_versions()]
888         print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
889         print "available", self.abbrev_verinfo_dict(servermap.shares_available())
890
891     def do_download(self, servermap, version=None):
892         if version is None:
893             version = servermap.best_recoverable_version()
894         r = Retrieve(self._fn, servermap, version)
895         return r.download()
896
897     def test_basic(self):
898         d = self.make_servermap()
899         def _do_retrieve(servermap):
900             self._smap = servermap
901             #self.dump_servermap(servermap)
902             self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
903             return self.do_download(servermap)
904         d.addCallback(_do_retrieve)
905         def _retrieved(new_contents):
906             self.failUnlessEqual(new_contents, self.CONTENTS)
907         d.addCallback(_retrieved)
908         # we should be able to re-use the same servermap, both with and
909         # without updating it.
910         d.addCallback(lambda res: self.do_download(self._smap))
911         d.addCallback(_retrieved)
912         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
913         d.addCallback(lambda res: self.do_download(self._smap))
914         d.addCallback(_retrieved)
915         # clobbering the pubkey should make the servermap updater re-fetch it
916         def _clobber_pubkey(res):
917             self._fn._pubkey = None
918         d.addCallback(_clobber_pubkey)
919         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
920         d.addCallback(lambda res: self.do_download(self._smap))
921         d.addCallback(_retrieved)
922         return d
923
924     def test_all_shares_vanished(self):
925         d = self.make_servermap()
926         def _remove_shares(servermap):
927             for shares in self._storage._peers.values():
928                 shares.clear()
929             d1 = self.shouldFail(NotEnoughSharesError,
930                                  "test_all_shares_vanished",
931                                  "ran out of peers",
932                                  self.do_download, servermap)
933             return d1
934         d.addCallback(_remove_shares)
935         return d
936
937     def test_no_servers(self):
938         sb2 = make_storagebroker(num_peers=0)
939         # if there are no servers, then a MODE_READ servermap should come
940         # back empty
941         d = self.make_servermap(sb=sb2)
942         def _check_servermap(servermap):
943             self.failUnlessEqual(servermap.best_recoverable_version(), None)
944             self.failIf(servermap.recoverable_versions())
945             self.failIf(servermap.unrecoverable_versions())
946             self.failIf(servermap.all_peers())
947         d.addCallback(_check_servermap)
948         return d
949     test_no_servers.timeout = 15
950
951     def test_no_servers_download(self):
952         sb2 = make_storagebroker(num_peers=0)
953         self._fn._storage_broker = sb2
954         d = self.shouldFail(UnrecoverableFileError,
955                             "test_no_servers_download",
956                             "no recoverable versions",
957                             self._fn.download_best_version)
958         def _restore(res):
959             # a failed download that occurs while we aren't connected to
960             # anybody should not prevent a subsequent download from working.
961             # This isn't quite the webapi-driven test that #463 wants, but it
962             # should be close enough.
963             self._fn._storage_broker = self._storage_broker
964             return self._fn.download_best_version()
965         def _retrieved(new_contents):
966             self.failUnlessEqual(new_contents, self.CONTENTS)
967         d.addCallback(_restore)
968         d.addCallback(_retrieved)
969         return d
970     test_no_servers_download.timeout = 15
971
972     def _test_corrupt_all(self, offset, substring,
973                           should_succeed=False, corrupt_early=True,
974                           failure_checker=None):
975         d = defer.succeed(None)
976         if corrupt_early:
977             d.addCallback(corrupt, self._storage, offset)
978         d.addCallback(lambda res: self.make_servermap())
979         if not corrupt_early:
980             d.addCallback(corrupt, self._storage, offset)
981         def _do_retrieve(servermap):
982             ver = servermap.best_recoverable_version()
983             if ver is None and not should_succeed:
984                 # no recoverable versions == not succeeding. The problem
985                 # should be noted in the servermap's list of problems.
986                 if substring:
987                     allproblems = [str(f) for f in servermap.problems]
988                     self.failUnlessIn(substring, "".join(allproblems))
989                 return servermap
990             if should_succeed:
991                 d1 = self._fn.download_version(servermap, ver)
992                 d1.addCallback(lambda new_contents:
993                                self.failUnlessEqual(new_contents, self.CONTENTS))
994             else:
995                 d1 = self.shouldFail(NotEnoughSharesError,
996                                      "_corrupt_all(offset=%s)" % (offset,),
997                                      substring,
998                                      self._fn.download_version, servermap, ver)
999             if failure_checker:
1000                 d1.addCallback(failure_checker)
1001             d1.addCallback(lambda res: servermap)
1002             return d1
1003         d.addCallback(_do_retrieve)
1004         return d
1005
1006     def test_corrupt_all_verbyte(self):
1007         # when the version byte is not 0, we hit an UnknownVersionError error
1008         # in unpack_share().
1009         d = self._test_corrupt_all(0, "UnknownVersionError")
1010         def _check_servermap(servermap):
1011             # and the dump should mention the problems
1012             s = StringIO()
1013             dump = servermap.dump(s).getvalue()
1014             self.failUnless("10 PROBLEMS" in dump, dump)
1015         d.addCallback(_check_servermap)
1016         return d
1017
1018     def test_corrupt_all_seqnum(self):
1019         # a corrupt sequence number will trigger a bad signature
1020         return self._test_corrupt_all(1, "signature is invalid")
1021
1022     def test_corrupt_all_R(self):
1023         # a corrupt root hash will trigger a bad signature
1024         return self._test_corrupt_all(9, "signature is invalid")
1025
1026     def test_corrupt_all_IV(self):
1027         # a corrupt salt/IV will trigger a bad signature
1028         return self._test_corrupt_all(41, "signature is invalid")
1029
1030     def test_corrupt_all_k(self):
1031         # a corrupt 'k' will trigger a bad signature
1032         return self._test_corrupt_all(57, "signature is invalid")
1033
1034     def test_corrupt_all_N(self):
1035         # a corrupt 'N' will trigger a bad signature
1036         return self._test_corrupt_all(58, "signature is invalid")
1037
1038     def test_corrupt_all_segsize(self):
1039         # a corrupt segsize will trigger a bad signature
1040         return self._test_corrupt_all(59, "signature is invalid")
1041
1042     def test_corrupt_all_datalen(self):
1043         # a corrupt data length will trigger a bad signature
1044         return self._test_corrupt_all(67, "signature is invalid")
1045
1046     def test_corrupt_all_pubkey(self):
1047         # a corrupt pubkey won't match the URI's fingerprint. We need to
1048         # remove the pubkey from the filenode, or else it won't bother trying
1049         # to update it.
1050         self._fn._pubkey = None
1051         return self._test_corrupt_all("pubkey",
1052                                       "pubkey doesn't match fingerprint")
1053
1054     def test_corrupt_all_sig(self):
1055         # a corrupt signature is a bad one
1056         # the signature runs from about [543:799], depending upon the length
1057         # of the pubkey
1058         return self._test_corrupt_all("signature", "signature is invalid")
1059
1060     def test_corrupt_all_share_hash_chain_number(self):
1061         # a corrupt share hash chain entry will show up as a bad hash. If we
1062         # mangle the first byte, that will look like a bad hash number,
1063         # causing an IndexError
1064         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1065
1066     def test_corrupt_all_share_hash_chain_hash(self):
1067         # a corrupt share hash chain entry will show up as a bad hash. If we
1068         # mangle a few bytes in, that will look like a bad hash.
1069         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1070
1071     def test_corrupt_all_block_hash_tree(self):
1072         return self._test_corrupt_all("block_hash_tree",
1073                                       "block hash tree failure")
1074
1075     def test_corrupt_all_block(self):
1076         return self._test_corrupt_all("share_data", "block hash tree failure")
1077
1078     def test_corrupt_all_encprivkey(self):
1079         # a corrupted privkey won't even be noticed by the reader, only by a
1080         # writer.
1081         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1082
1083
1084     def test_corrupt_all_seqnum_late(self):
1085         # corrupting the seqnum between mapupdate and retrieve should result
1086         # in NotEnoughSharesError, since each share will look invalid
1087         def _check(res):
1088             f = res[0]
1089             self.failUnless(f.check(NotEnoughSharesError))
1090             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1091         return self._test_corrupt_all(1, "ran out of peers",
1092                                       corrupt_early=False,
1093                                       failure_checker=_check)
1094
1095     def test_corrupt_all_block_hash_tree_late(self):
1096         def _check(res):
1097             f = res[0]
1098             self.failUnless(f.check(NotEnoughSharesError))
1099         return self._test_corrupt_all("block_hash_tree",
1100                                       "block hash tree failure",
1101                                       corrupt_early=False,
1102                                       failure_checker=_check)
1103
1104
1105     def test_corrupt_all_block_late(self):
1106         def _check(res):
1107             f = res[0]
1108             self.failUnless(f.check(NotEnoughSharesError))
1109         return self._test_corrupt_all("share_data", "block hash tree failure",
1110                                       corrupt_early=False,
1111                                       failure_checker=_check)
1112
1113
1114     def test_basic_pubkey_at_end(self):
1115         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1116         # download to succeed but forcing a bunch of retries first. Note that
1117         # this is rather pessimistic: our Retrieve process will throw away
1118         # the whole share if the pubkey is bad, even though the rest of the
1119         # share might be good.
1120
1121         self._fn._pubkey = None
1122         k = self._fn.get_required_shares()
1123         N = self._fn.get_total_shares()
1124         d = defer.succeed(None)
1125         d.addCallback(corrupt, self._storage, "pubkey",
1126                       shnums_to_corrupt=range(0, N-k))
1127         d.addCallback(lambda res: self.make_servermap())
1128         def _do_retrieve(servermap):
1129             self.failUnless(servermap.problems)
1130             self.failUnless("pubkey doesn't match fingerprint"
1131                             in str(servermap.problems[0]))
1132             ver = servermap.best_recoverable_version()
1133             r = Retrieve(self._fn, servermap, ver)
1134             return r.download()
1135         d.addCallback(_do_retrieve)
1136         d.addCallback(lambda new_contents:
1137                       self.failUnlessEqual(new_contents, self.CONTENTS))
1138         return d
1139
1140     def test_corrupt_some(self):
1141         # corrupt the data of first five shares (so the servermap thinks
1142         # they're good but retrieve marks them as bad), so that the
1143         # MODE_READ set of 6 will be insufficient, forcing node.download to
1144         # retry with more servers.
1145         corrupt(None, self._storage, "share_data", range(5))
1146         d = self.make_servermap()
1147         def _do_retrieve(servermap):
1148             ver = servermap.best_recoverable_version()
1149             self.failUnless(ver)
1150             return self._fn.download_best_version()
1151         d.addCallback(_do_retrieve)
1152         d.addCallback(lambda new_contents:
1153                       self.failUnlessEqual(new_contents, self.CONTENTS))
1154         return d
1155
1156     def test_download_fails(self):
1157         corrupt(None, self._storage, "signature")
1158         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1159                             "no recoverable versions",
1160                             self._fn.download_best_version)
1161         return d
1162
1163
class CheckerMixin:
    # Assertion helpers for check/verify results. They rely on the
    # failUnless/failIf/fail methods of the class they are mixed into, and
    # take the results object first so they can be used directly as Deferred
    # callbacks; 'where' is a label that ends up in the failure message.

    def check_good(self, r, where):
        # require r.is_healthy() to be true; hand 'r' on to the next callback
        healthy = r.is_healthy()
        self.failUnless(healthy, where)
        return r

    def check_bad(self, r, where):
        # require r.is_healthy() to be false; hand 'r' on to the next callback
        healthy = r.is_healthy()
        self.failIf(healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        # Look through r.problems for the first failure of type
        # 'expected_exception' and require 'substring' in its string form.
        # It is a test failure if no problem of that type is present.
        for problem in r.problems:
            (peerid, storage_index, shnum, f) = problem
            if not f.check(expected_exception):
                continue
            text = str(f)
            complaint = ("%s: substring '%s' not in '%s'" %
                         (where, substring, text))
            self.failUnless(substring in text, complaint)
            return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))
1182
1183
class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    # Runs self._fn.check(), with and without verify=True, against the file
    # published in setUp() after damaging or removing some of its shares,
    # and judges the outcome with the CheckerMixin helpers.

    def setUp(self):
        return self.publish_one()


    def test_check_good(self):
        results_d = self._fn.check(Monitor())
        return results_d.addCallback(self.check_good, "test_check_good")

    def test_check_no_shares(self):
        # empty every peer in place
        peers = self._storage._peers
        for peerid in peers:
            peers[peerid].clear()
        results_d = self._fn.check(Monitor())
        return results_d.addCallback(self.check_bad, "test_check_no_shares")

    def test_check_not_enough_shares(self):
        # throw away every share except share number 0
        for shares in self._storage._peers.values():
            doomed = [shnum for shnum in shares if shnum > 0]
            for shnum in doomed:
                del shares[shnum]
        results_d = self._fn.check(Monitor())
        return results_d.addCallback(self.check_bad,
                                     "test_check_not_enough_shares")

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        results_d = self._fn.check(Monitor())
        return results_d.addCallback(self.check_bad, "test_check_all_bad_sig")

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this.. it doesn't look at actual data
        results_d = self._fn.check(Monitor())
        return results_d.addCallback(self.check_good,
                                     "test_check_all_bad_blocks")

    def test_verify_good(self):
        results_d = self._fn.check(Monitor(), verify=True)
        return results_d.addCallback(self.check_good, "test_verify_good")

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        results_d = self._fn.check(Monitor(), verify=True)
        return results_d.addCallback(self.check_bad, "test_verify_all_bad_sig")

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        results_d = self._fn.check(Monitor(), verify=True)
        return results_d.addCallback(self.check_bad, "test_verify_one_bad_sig")

    def test_verify_one_bad_block(self):
        where = "test_verify_one_bad_block"
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        results_d = self._fn.check(Monitor(), verify=True)
        results_d.addCallback(self.check_bad, where)
        results_d.addCallback(self.check_expected_failure,
                              CorruptShareError, "block hash tree failure",
                              where)
        return results_d

    def test_verify_one_bad_sharehash(self):
        where = "test_verify_one_bad_sharehash"
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        results_d = self._fn.check(Monitor(), verify=True)
        results_d.addCallback(self.check_bad, where)
        results_d.addCallback(self.check_expected_failure,
                              CorruptShareError, "corrupt hashes",
                              where)
        return results_d

    def test_verify_one_bad_encprivkey(self):
        where = "test_verify_one_bad_encprivkey"
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        results_d = self._fn.check(Monitor(), verify=True)
        results_d.addCallback(self.check_bad, where)
        results_d.addCallback(self.check_expected_failure,
                              CorruptShareError, "invalid privkey",
                              where)
        return results_d

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        # a read-only node has no way to validate the privkey
        readonly_fn = self._fn.get_readonly()
        results_d = readonly_fn.check(Monitor(), verify=True)
        return results_d.addCallback(
            self.check_good, "test_verify_one_bad_encprivkey_uncheckable")
1276
1277 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1278
1279     def get_shares(self, s):
1280         all_shares = {} # maps (peerid, shnum) to share data
1281         for peerid in s._peers:
1282             shares = s._peers[peerid]
1283             for shnum in shares:
1284                 data = shares[shnum]
1285                 all_shares[ (peerid, shnum) ] = data
1286         return all_shares
1287
1288     def copy_shares(self, ignored=None):
1289         self.old_shares.append(self.get_shares(self._storage))
1290
1291     def test_repair_nop(self):
1292         self.old_shares = []
1293         d = self.publish_one()
1294         d.addCallback(self.copy_shares)
1295         d.addCallback(lambda res: self._fn.check(Monitor()))
1296         d.addCallback(lambda check_results: self._fn.repair(check_results))
1297         def _check_results(rres):
1298             self.failUnless(IRepairResults.providedBy(rres))
1299             self.failUnless(rres.get_successful())
1300             # TODO: examine results
1301
1302             self.copy_shares()
1303
1304             initial_shares = self.old_shares[0]
1305             new_shares = self.old_shares[1]
1306             # TODO: this really shouldn't change anything. When we implement
1307             # a "minimal-bandwidth" repairer", change this test to assert:
1308             #self.failUnlessEqual(new_shares, initial_shares)
1309
1310             # all shares should be in the same place as before
1311             self.failUnlessEqual(set(initial_shares.keys()),
1312                                  set(new_shares.keys()))
1313             # but they should all be at a newer seqnum. The IV will be
1314             # different, so the roothash will be too.
1315             for key in initial_shares:
1316                 (version0,
1317                  seqnum0,
1318                  root_hash0,
1319                  IV0,
1320                  k0, N0, segsize0, datalen0,
1321                  o0) = unpack_header(initial_shares[key])
1322                 (version1,
1323                  seqnum1,
1324                  root_hash1,
1325                  IV1,
1326                  k1, N1, segsize1, datalen1,
1327                  o1) = unpack_header(new_shares[key])
1328                 self.failUnlessEqual(version0, version1)
1329                 self.failUnlessEqual(seqnum0+1, seqnum1)
1330                 self.failUnlessEqual(k0, k1)
1331                 self.failUnlessEqual(N0, N1)
1332                 self.failUnlessEqual(segsize0, segsize1)
1333                 self.failUnlessEqual(datalen0, datalen1)
1334         d.addCallback(_check_results)
1335         return d
1336
1337     def failIfSharesChanged(self, ignored=None):
1338         old_shares = self.old_shares[-2]
1339         current_shares = self.old_shares[-1]
1340         self.failUnlessEqual(old_shares, current_shares)
1341
1342     def test_unrepairable_0shares(self):
1343         d = self.publish_one()
1344         def _delete_all_shares(ign):
1345             shares = self._storage._peers
1346             for peerid in shares:
1347                 shares[peerid] = {}
1348         d.addCallback(_delete_all_shares)
1349         d.addCallback(lambda ign: self._fn.check(Monitor()))
1350         d.addCallback(lambda check_results: self._fn.repair(check_results))
1351         def _check(crr):
1352             self.failUnlessEqual(crr.get_successful(), False)
1353         d.addCallback(_check)
1354         return d
1355
1356     def test_unrepairable_1share(self):
1357         d = self.publish_one()
1358         def _delete_all_shares(ign):
1359             shares = self._storage._peers
1360             for peerid in shares:
1361                 for shnum in list(shares[peerid]):
1362                     if shnum > 0:
1363                         del shares[peerid][shnum]
1364         d.addCallback(_delete_all_shares)
1365         d.addCallback(lambda ign: self._fn.check(Monitor()))
1366         d.addCallback(lambda check_results: self._fn.repair(check_results))
1367         def _check(crr):
1368             self.failUnlessEqual(crr.get_successful(), False)
1369         d.addCallback(_check)
1370         return d
1371
1372     def test_merge(self):
1373         self.old_shares = []
1374         d = self.publish_multiple()
1375         # repair will refuse to merge multiple highest seqnums unless you
1376         # pass force=True
1377         d.addCallback(lambda res:
1378                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1379                                           1:4,3:4,5:4,7:4,9:4}))
1380         d.addCallback(self.copy_shares)
1381         d.addCallback(lambda res: self._fn.check(Monitor()))
1382         def _try_repair(check_results):
1383             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1384             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1385                                  self._fn.repair, check_results)
1386             d2.addCallback(self.copy_shares)
1387             d2.addCallback(self.failIfSharesChanged)
1388             d2.addCallback(lambda res: check_results)
1389             return d2
1390         d.addCallback(_try_repair)
1391         d.addCallback(lambda check_results:
1392                       self._fn.repair(check_results, force=True))
1393         # this should give us 10 shares of the highest roothash
1394         def _check_repair_results(rres):
1395             self.failUnless(rres.get_successful())
1396             pass # TODO
1397         d.addCallback(_check_repair_results)
1398         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1399         def _check_smap(smap):
1400             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1401             self.failIf(smap.unrecoverable_versions())
1402             # now, which should have won?
1403             roothash_s4a = self.get_roothash_for(3)
1404             roothash_s4b = self.get_roothash_for(4)
1405             if roothash_s4b > roothash_s4a:
1406                 expected_contents = self.CONTENTS[4]
1407             else:
1408                 expected_contents = self.CONTENTS[3]
1409             new_versionid = smap.best_recoverable_version()
1410             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1411             d2 = self._fn.download_version(smap, new_versionid)
1412             d2.addCallback(self.failUnlessEqual, expected_contents)
1413             return d2
1414         d.addCallback(_check_smap)
1415         return d
1416
1417     def test_non_merge(self):
1418         self.old_shares = []
1419         d = self.publish_multiple()
1420         # repair should not refuse a repair that doesn't need to merge. In
1421         # this case, we combine v2 with v3. The repair should ignore v2 and
1422         # copy v3 into a new v5.
1423         d.addCallback(lambda res:
1424                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1425                                           1:3,3:3,5:3,7:3,9:3}))
1426         d.addCallback(lambda res: self._fn.check(Monitor()))
1427         d.addCallback(lambda check_results: self._fn.repair(check_results))
1428         # this should give us 10 shares of v3
1429         def _check_repair_results(rres):
1430             self.failUnless(rres.get_successful())
1431             pass # TODO
1432         d.addCallback(_check_repair_results)
1433         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1434         def _check_smap(smap):
1435             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1436             self.failIf(smap.unrecoverable_versions())
1437             # now, which should have won?
1438             roothash_s4a = self.get_roothash_for(3)
1439             expected_contents = self.CONTENTS[3]
1440             new_versionid = smap.best_recoverable_version()
1441             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1442             d2 = self._fn.download_version(smap, new_versionid)
1443             d2.addCallback(self.failUnlessEqual, expected_contents)
1444             return d2
1445         d.addCallback(_check_smap)
1446         return d
1447
1448     def get_roothash_for(self, index):
1449         # return the roothash for the first share we see in the saved set
1450         shares = self._copied_shares[index]
1451         for peerid in shares:
1452             for shnum in shares[peerid]:
1453                 share = shares[peerid][shnum]
1454                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1455                           unpack_header(share)
1456                 return root_hash
1457
1458     def test_check_and_repair_readcap(self):
1459         # we can't currently repair from a mutable readcap: #625
1460         self.old_shares = []
1461         d = self.publish_one()
1462         d.addCallback(self.copy_shares)
1463         def _get_readcap(res):
1464             self._fn3 = self._fn.get_readonly()
1465             # also delete some shares
1466             for peerid,shares in self._storage._peers.items():
1467                 shares.pop(0, None)
1468         d.addCallback(_get_readcap)
1469         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1470         def _check_results(crr):
1471             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1472             # we should detect the unhealthy, but skip over mutable-readcap
1473             # repairs until #625 is fixed
1474             self.failIf(crr.get_pre_repair_results().is_healthy())
1475             self.failIf(crr.get_repair_attempted())
1476             self.failIf(crr.get_post_repair_results().is_healthy())
1477         d.addCallback(_check_results)
1478         return d
1479
class DevNullDictionary(dict):
    # A dict whose item assignment is a no-op: d[key] = value stores
    # nothing. Only __setitem__ is overridden, so entries supplied to the
    # constructor are kept.
    def __setitem__(self, key, value):
        pass
1483
class MultipleEncodings(unittest.TestCase):
    # Publishes one mutable file under several different k-of-N encodings,
    # interleaves the resulting shares in the fake storage, and checks which
    # contents a download then returns.

    def setUp(self):
        self.CONTENTS = "New contents go here"
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage, num_peers=20)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict, using k-of-n encoding.
        # Returns a Deferred that fires with that dict. The fake storage is
        # emptied both before publishing and after the shares are collected.

        fn = self._fn
        # disable the nodecache, since for these tests we explicitly need
        # multiple nodes pointing at the same file
        self._nodemaker._node_cache = DevNullDictionary()
        fn2 = self._nodemaker.create_from_cap(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, self._storage_broker, None)
        d = p2.publish(data)
        def _published(res):
            # hand the freshly-written shares to the caller and leave the
            # storage empty for whatever comes next
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        # Run a ServermapUpdater for self._fn in the given mode, starting
        # from 'oldmap' (or a fresh ServerMap). Returns update()'s Deferred.
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                               oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = self._nodemaker.create_from_cap(self._fn.get_uri())

        # now we publish each encoding in turn, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2 or
            # set3 according to the following sequence:
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # Going down that list, s1 (k=3) has its third share at [g] and
            # s2 (k=4) has its fourth share at [h], so when shares are
            # fetched in shnum order version-s1 (the 3-of-10 form) is the
            # first one to become recoverable.

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
            # the saved share set that each value in 'places' refers to
            sources = {1: self._shares1,
                       2: self._shares2,
                       3: self._shares3}

            sharemap = {}
            sb = self._storage_broker

            for peerid in sorted(sb.get_all_serverids()):
                for shnum in self._shares1.get(peerid, {}):
                    # setdefault, not a fresh dict per share: a server that
                    # held several shares in set1 must keep the ones that
                    # were already placed on an earlier pass of this loop
                    peers = self._storage._peers.setdefault(peerid, {})
                    if shnum >= len(places):
                        # we don't bother placing any other shares
                        continue
                    source = sources[places[shnum]]
                    if shnum in source.get(peerid, {}):
                        peers[shnum] = source[peerid][shnum]
                        sharemap[shnum] = peerid

            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
1626
1627
class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
    # Each test starts from whatever publish_multiple() leaves behind and
    # then uses _set_versions() to mix shares of different versions.

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        def _download(ign):
            return self._fn.download_best_version()

        def _expect_contents(index):
            # build a callback that compares a download to CONTENTS[index]
            def _compare(res):
                return self.failUnlessEqual(res, self.CONTENTS[index])
            return _compare

        def _fetch_read_map(ign):
            return self._fn.get_servermap(MODE_READ)

        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(dict.fromkeys((0, 2, 4, 6, 8), 2))
        d = self._fn.download_best_version()
        d.addCallback(_expect_contents(4))
        # and the checker should report problems
        d.addCallback(lambda ign: self._fn.check(Monitor()))
        d.addCallback(self.check_bad, "test_multiple_versions")

        # but if everything is at version 2, that's what we should download
        def _everything_at_2(ign):
            return self._set_versions(dict.fromkeys(range(10), 2))
        d.addCallback(_everything_at_2)
        d.addCallback(_download)
        d.addCallback(_expect_contents(2))

        # if exactly one share is at version 3, we should still get v2
        def _one_share_at_3(ign):
            return self._set_versions({0: 3})
        d.addCallback(_one_share_at_3)
        d.addCallback(_download)
        d.addCallback(_expect_contents(2))

        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(_fetch_read_map)
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = list(newer.items())[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)

        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        def _two_parallel_versions(ign):
            return self._set_versions({0:3, 2:3, 4:3, 6:3, 8:3,
                                       1:4, 3:4, 5:4, 7:4, 9:4})
        d.addCallback(_two_parallel_versions)
        d.addCallback(_fetch_read_map)
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(_download)
        def _either_parallel_version(res):
            acceptable = (self.CONTENTS[3], self.CONTENTS[4])
            return self.failUnless(res in acceptable)
        d.addCallback(_either_parallel_version)
        return d

    def test_replace(self):
        # if we see a mix of versions in the grid, we should be able to
        # replace them all with a newer version

        # if exactly one share is at version 3, we should download (and
        # replace) v2, and the result should be v4. Note that the index we
        # give to _set_versions is different than the sequence number.
        target = dict.fromkeys(range(10), 2) # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        expected = self.CONTENTS[2] + " modified"

        def _modify(oldversion, servermap, first_time):
            return oldversion + " modified"

        def _download(ign):
            return self._fn.download_best_version()

        def _got_modified_contents(res):
            return self.failUnlessEqual(res, expected)

        def _fetch_check_map(ign):
            return self._fn.get_servermap(MODE_CHECK)

        def _check_smap(smap):
            # the servermap should indicate that the outlier was replaced too
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)

        d = self._fn.modify(_modify)
        d.addCallback(_download)
        d.addCallback(_got_modified_contents)
        d.addCallback(_fetch_check_map)
        d.addCallback(_check_smap)
        return d
1707
1708
class Utils(unittest.TestCase):
    # Tests for ResponseCache: its span helpers (_inside, _does_overlap) are
    # compared against brute-force sets of integers, and add()/read() are
    # checked against fixed data.

    def _sweep(self, check):
        # run 'check' for a fixed x span (start 10, length 5) against a grid
        # of y spans, all on one shared cache instance
        c = ResponseCache()
        for y_start in range(8, 17):
            for y_length in range(8):
                check(c, 10, 5, y_start, y_length)

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        spans = (x_start, x_length, y_start, y_length)
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        self.failUnlessEqual(x <= y, c._inside(*spans), str(spans))

    def test_cache_inside(self):
        self._sweep(self._do_inside)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        spans = (x_start, x_length, y_start, y_length)
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        self.failUnlessEqual(bool(x & y), c._does_overlap(*spans), str(spans))

    def test_cache_overlap(self):
        self._sweep(self._do_overlap)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")

        # (arguments to read(), expected result), checked in this order
        expectations = [
            (("v2", 1, 10, 11), nope),
            (("v1", 2, 10, 11), nope),
            (("v1", 1, 0, 10), (xdata[:10], "time0")),
            (("v1", 1, 90, 10), (xdata[90:], "time0")),
            (("v1", 1, 300, 10), nope),
            (("v1", 1, 2050, 5), (ydata[50:55], "time1")),
            (("v1", 1, 0, 101), nope),
            (("v1", 1, 99, 1), (xdata[99:100], "time0")),
            (("v1", 1, 100, 1), nope),
            (("v1", 1, 1990, 9), nope),
            (("v1", 1, 1990, 10), nope),
            (("v1", 1, 1990, 11), nope),
            (("v1", 1, 1990, 15), nope),
            (("v1", 1, 1990, 19), nope),
            (("v1", 1, 1990, 20), nope),
            (("v1", 1, 1990, 21), nope),
            (("v1", 1, 1990, 25), nope),
            (("v1", 1, 1999, 25), nope),
            ]
        for read_args, expected in expectations:
            self.failUnlessEqual(c.read(*read_args), expected)

        # optional: join fragments
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1776
class Exceptions(unittest.TestCase):
    def test_repr(self):
        # the repr() of each exception should mention its class name
        def _repr_names_class(exc, classname):
            text = repr(exc)
            self.failUnless(classname in text, text)
        _repr_names_class(NeedMoreDataError(100, 50, 100),
                          "NeedMoreDataError")
        _repr_names_class(UncoordinatedWriteError(),
                          "UncoordinatedWriteError")
1783
class SameKeyGenerator:
    # A key generator that always hands back the one keypair it was
    # constructed with.
    def __init__(self, pubkey, privkey):
        self.pubkey, self.privkey = pubkey, privkey
    def generate(self, keysize=None):
        # 'keysize' is accepted but not used; the result is an
        # already-fired Deferred carrying (pubkey, privkey)
        keypair = (self.pubkey, self.privkey)
        return defer.succeed(keypair)
1790
class FirstServerGetsKilled:
    # On the first notify() call, mark the wrapper involved as broken.
    # Every call, first or later, returns retval unchanged.
    done = False
    def notify(self, retval, wrapper, methname):
        if self.done:
            return retval
        wrapper.broken = True
        self.done = True
        return retval
1798
class FirstServerGetsDeleted:
    # The first notify() call passes its result through and remembers the
    # wrapper involved. After that, calls on that same wrapper get the fixed
    # answer (True, {}) in place of the real result; calls on any other
    # wrapper pass through unchanged.
    def __init__(self):
        self.done = False
        self.silenced = None
    def notify(self, retval, wrapper, methname):
        if self.done:
            if wrapper == self.silenced:
                assert methname == "slot_testv_and_readv_and_writev"
                return (True, {})
            return retval
        # this query will work, but later queries should think the share
        # has been deleted
        self.done = True
        self.silenced = wrapper
        return retval
1814
1815 class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
1816     def test_publish_surprise(self):
1817         self.basedir = "mutable/Problems/test_publish_surprise"
1818         self.set_up_grid()
1819         nm = self.g.clients[0].nodemaker
1820         d = nm.create_mutable_file("contents 1")
1821         def _created(n):
1822             d = defer.succeed(None)
1823             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1824             def _got_smap1(smap):
1825                 # stash the old state of the file
1826                 self.old_map = smap
1827             d.addCallback(_got_smap1)
1828             # then modify the file, leaving the old map untouched
1829             d.addCallback(lambda res: log.msg("starting winning write"))
1830             d.addCallback(lambda res: n.overwrite("contents 2"))
1831             # now attempt to modify the file with the old servermap. This
1832             # will look just like an uncoordinated write, in which every
1833             # single share got updated between our mapupdate and our publish
1834             d.addCallback(lambda res: log.msg("starting doomed write"))
1835             d.addCallback(lambda res:
1836                           self.shouldFail(UncoordinatedWriteError,
1837                                           "test_publish_surprise", None,
1838                                           n.upload,
1839                                           "contents 2a", self.old_map))
1840             return d
1841         d.addCallback(_created)
1842         return d
1843
1844     def test_retrieve_surprise(self):
1845         self.basedir = "mutable/Problems/test_retrieve_surprise"
1846         self.set_up_grid()
1847         nm = self.g.clients[0].nodemaker
1848         d = nm.create_mutable_file("contents 1")
1849         def _created(n):
1850             d = defer.succeed(None)
1851             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1852             def _got_smap1(smap):
1853                 # stash the old state of the file
1854                 self.old_map = smap
1855             d.addCallback(_got_smap1)
1856             # then modify the file, leaving the old map untouched
1857             d.addCallback(lambda res: log.msg("starting winning write"))
1858             d.addCallback(lambda res: n.overwrite("contents 2"))
1859             # now attempt to retrieve the old version with the old servermap.
1860             # This will look like someone has changed the file since we
1861             # updated the servermap.
1862             d.addCallback(lambda res: n._cache._clear())
1863             d.addCallback(lambda res: log.msg("starting doomed read"))
1864             d.addCallback(lambda res:
1865                           self.shouldFail(NotEnoughSharesError,
1866                                           "test_retrieve_surprise",
1867                                           "ran out of peers: have 0 shares (k=3)",
1868                                           n.download_version,
1869                                           self.old_map,
1870                                           self.old_map.best_recoverable_version(),
1871                                           ))
1872             return d
1873         d.addCallback(_created)
1874         return d
1875
1876     def test_unexpected_shares(self):
1877         # upload the file, take a servermap, shut down one of the servers,
1878         # upload it again (causing shares to appear on a new server), then
1879         # upload using the old servermap. The last upload should fail with an
1880         # UncoordinatedWriteError, because of the shares that didn't appear
1881         # in the servermap.
1882         self.basedir = "mutable/Problems/test_unexpected_shares"
1883         self.set_up_grid()
1884         nm = self.g.clients[0].nodemaker
1885         d = nm.create_mutable_file("contents 1")
1886         def _created(n):
1887             d = defer.succeed(None)
1888             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1889             def _got_smap1(smap):
1890                 # stash the old state of the file
1891                 self.old_map = smap
1892                 # now shut down one of the servers
1893                 peer0 = list(smap.make_sharemap()[0])[0]
1894                 self.g.remove_server(peer0)
1895                 # then modify the file, leaving the old map untouched
1896                 log.msg("starting winning write")
1897                 return n.overwrite("contents 2")
1898             d.addCallback(_got_smap1)
1899             # now attempt to modify the file with the old servermap. This
1900             # will look just like an uncoordinated write, in which every
1901             # single share got updated between our mapupdate and our publish
1902             d.addCallback(lambda res: log.msg("starting doomed write"))
1903             d.addCallback(lambda res:
1904                           self.shouldFail(UncoordinatedWriteError,
1905                                           "test_surprise", None,
1906                                           n.upload,
1907                                           "contents 2a", self.old_map))
1908             return d
1909         d.addCallback(_created)
1910         return d
1911
    def test_bad_server(self):
        # Break one server, then create the file: the initial publish should
        # complete with an alternate server. Breaking a second server should
        # not prevent an update from succeeding either.
        self.basedir = "mutable/Problems/test_bad_server"
        self.set_up_grid()
        nm = self.g.clients[0].nodemaker

        # to make sure that one of the initial peers is broken, we have to
        # get creative. We create an RSA key and compute its storage-index.
        # Then we make a KeyGenerator that always returns that one key, and
        # use it to create the mutable file. This will get easier when we can
        # use #467 static-server-selection to disable permutation and force
        # the choice of server for share[0].

        d = nm.key_generator.generate(522)
        def _got_key( (pubkey, privkey) ):
            # from here on the nodemaker will only ever produce this keypair,
            # so the storage index computed below is the one the new file
            # will get
            nm.key_generator = SameKeyGenerator(pubkey, privkey)
            pubkey_s = pubkey.serialize()
            privkey_s = privkey.serialize()
            u = uri.WriteableSSKFileURI(ssk_writekey_hash(privkey_s),
                                        ssk_pubkey_fingerprint_hash(pubkey_s))
            self._storage_index = u.storage_index
        d.addCallback(_got_key)
        def _break_peer0(res):
            # mark the first server in this storage index's server list as
            # broken, and remember the second one so it can be broken later
            si = self._storage_index
            peerlist = nm.storage_broker.get_servers_for_index(si)
            peerid0, connection0 = peerlist[0]
            peerid1, connection1 = peerlist[1]
            connection0.broken = True
            self.connection1 = connection1
        d.addCallback(_break_peer0)
        # now "create" the file, using the pre-established key, and let the
        # initial publish finally happen
        d.addCallback(lambda res: nm.create_mutable_file("contents 1"))
        # that ought to work
        def _got_node(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            # now break the second peer
            def _break_peer1(res):
                self.connection1.broken = True
            d.addCallback(_break_peer1)
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # that ought to work too
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            def _explain_error(f):
                # debugging aid: print the failure (and, for
                # NotEnoughServersError, its first_error), then return the
                # failure so the test still fails
                print f
                if f.check(NotEnoughServersError):
                    print "first_error:", f.value.first_error
                return f
            d.addErrback(_explain_error)
            return d
        d.addCallback(_got_node)
        return d
1968
1969     def test_bad_server_overlap(self):
1970         # like test_bad_server, but with no extra unused servers to fall back
1971         # upon. This means that we must re-use a server which we've already
1972         # used. If we don't remember the fact that we sent them one share
1973         # already, we'll mistakenly think we're experiencing an
1974         # UncoordinatedWriteError.
1975
1976         # Break one server, then create the file: the initial publish should
1977         # complete with an alternate server. Breaking a second server should
1978         # not prevent an update from succeeding either.
1979         self.basedir = "mutable/Problems/test_bad_server_overlap"
1980         self.set_up_grid()
1981         nm = self.g.clients[0].nodemaker
1982         sb = nm.storage_broker
1983
1984         peerids = [serverid for (serverid,ss) in sb.get_all_servers()]
1985         self.g.break_server(peerids[0])
1986
1987         d = nm.create_mutable_file("contents 1")
1988         def _created(n):
1989             d = n.download_best_version()
1990             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1991             # now break one of the remaining servers
1992             def _break_second_server(res):
1993                 self.g.break_server(peerids[1])
1994             d.addCallback(_break_second_server)
1995             d.addCallback(lambda res: n.overwrite("contents 2"))
1996             # that ought to work too
1997             d.addCallback(lambda res: n.download_best_version())
1998             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1999             return d
2000         d.addCallback(_created)
2001         return d
2002
2003     def test_publish_all_servers_bad(self):
2004         # Break all servers: the publish should fail
2005         self.basedir = "mutable/Problems/test_publish_all_servers_bad"
2006         self.set_up_grid()
2007         nm = self.g.clients[0].nodemaker
2008         for (serverid,ss) in nm.storage_broker.get_all_servers():
2009             ss.broken = True
2010
2011         d = self.shouldFail(NotEnoughServersError,
2012                             "test_publish_all_servers_bad",
2013                             "Ran out of non-bad servers",
2014                             nm.create_mutable_file, "contents")
2015         return d
2016
2017     def test_publish_no_servers(self):
2018         # no servers at all: the publish should fail
2019         self.basedir = "mutable/Problems/test_publish_no_servers"
2020         self.set_up_grid(num_servers=0)
2021         nm = self.g.clients[0].nodemaker
2022
2023         d = self.shouldFail(NotEnoughServersError,
2024                             "test_publish_no_servers",
2025                             "Ran out of non-bad servers",
2026                             nm.create_mutable_file, "contents")
2027         return d
2028     test_publish_no_servers.timeout = 30
2029
2030
2031     def test_privkey_query_error(self):
2032         # when a servermap is updated with MODE_WRITE, it tries to get the
2033         # privkey. Something might go wrong during this query attempt.
2034         # Exercise the code in _privkey_query_failed which tries to handle
2035         # such an error.
2036         self.basedir = "mutable/Problems/test_privkey_query_error"
2037         self.set_up_grid(num_servers=20)
2038         nm = self.g.clients[0].nodemaker
2039         nm._node_cache = DevNullDictionary() # disable the nodecache
2040
2041         # we need some contents that are large enough to push the privkey out
2042         # of the early part of the file
2043         LARGE = "These are Larger contents" * 2000 # about 50KB
2044         d = nm.create_mutable_file(LARGE)
2045         def _created(n):
2046             self.uri = n.get_uri()
2047             self.n2 = nm.create_from_cap(self.uri)
2048
2049             # When a mapupdate is performed on a node that doesn't yet know
2050             # the privkey, a short read is sent to a batch of servers, to get
2051             # the verinfo and (hopefully, if the file is short enough) the
2052             # encprivkey. Our file is too large to let this first read
2053             # contain the encprivkey. Each non-encprivkey-bearing response
2054             # that arrives (until the node gets the encprivkey) will trigger
2055             # a second read to specifically read the encprivkey.
2056             #
2057             # So, to exercise this case:
2058             #  1. notice which server gets a read() call first
2059             #  2. tell that server to start throwing errors
2060             killer = FirstServerGetsKilled()
2061             for (serverid,ss) in nm.storage_broker.get_all_servers():
2062                 ss.post_call_notifier = killer.notify
2063         d.addCallback(_created)
2064
2065         # now we update a servermap from a new node (which doesn't have the
2066         # privkey yet, forcing it to use a separate privkey query). Note that
2067         # the map-update will succeed, since we'll just get a copy from one
2068         # of the other shares.
2069         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2070
2071         return d
2072
2073     def test_privkey_query_missing(self):
2074         # like test_privkey_query_error, but the shares are deleted by the
2075         # second query, instead of raising an exception.
2076         self.basedir = "mutable/Problems/test_privkey_query_missing"
2077         self.set_up_grid(num_servers=20)
2078         nm = self.g.clients[0].nodemaker
2079         LARGE = "These are Larger contents" * 2000 # about 50KB
2080         nm._node_cache = DevNullDictionary() # disable the nodecache
2081
2082         d = nm.create_mutable_file(LARGE)
2083         def _created(n):
2084             self.uri = n.get_uri()
2085             self.n2 = nm.create_from_cap(self.uri)
2086             deleter = FirstServerGetsDeleted()
2087             for (serverid,ss) in nm.storage_broker.get_all_servers():
2088                 ss.post_call_notifier = deleter.notify
2089         d.addCallback(_created)
2090         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2091         return d