
import struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from allmydata import uri, client
from allmydata.nodemaker import NodeMaker
from allmydata.util import base32
from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
     ssk_pubkey_fingerprint_hash
from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
     NotEnoughSharesError
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from allmydata.test.no_network import GridTestMixin
from foolscap.api import eventually, fireEventually
from foolscap.logging import log
from allmydata.storage_client import StorageFarmBroker

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import allmydata.test.common_util as testutil

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
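        #
        # (Illustrative usage, not exercised here: a test assigns
        # s._sequence = list_of_peerids, after which each read() is
        # deferred and answered in that peer order a second later.)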
        self._sequence = None
        self._pending = {}
        self._pending_timer = None

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: swallow the result (and any errors)
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
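        # tw_vectors maps shnum -> (testv, writev, new_length); each testv
        # entry is (offset, length, op, specimen), and each writev entry
        # is (offset, data).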
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
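# e.g. flip_bit("\x00abc", 0) == "\x01abc" -- only the low bit of the byte
# at byte_offset changes.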

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
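    #
    # 'offset' selects the byte to flip: an absolute int offset, a section
    # name from the share's offset table (e.g. "signature"), the special
    # string "pubkey" (fixed at header byte 107), or a (name, delta) tuple.
    # offset_offset shifts the final position by that many bytes.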
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

def make_storagebroker(s=None, num_peers=10):
    if not s:
        s = FakeStorage()
    peerids = [tagged_hash("peerid", "%d" % i)[:20]
               for i in range(num_peers)]
    storage_broker = StorageFarmBroker(None, True)
    for peerid in peerids:
        fss = FakeStorageServer(peerid, s)
        storage_broker.test_add_server(peerid, fss)
    return storage_broker

def make_nodemaker(s=None, num_peers=10):
    storage_broker = make_storagebroker(s, num_peers)
    sh = client.SecretHolder("lease secret", "convergence secret")
    keygen = client.KeyGenerator()
    keygen.set_default_keysize(522)
    nodemaker = NodeMaker(storage_broker, sh, None,
                          None, None,
                          {"k": 3, "n": 10}, keygen)
    return nodemaker
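
# Typical wiring of these helpers (a sketch mirroring the setUp methods
# below):
#   s = FakeStorage()
#   nodemaker = make_nodemaker(s)
#   d = nodemaker.create_mutable_file("contents")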

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # this used to be in Publish, but we removed the limit. Some of
    # these tests test whether the new code correctly allows files
    # larger than the limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self._storage = s = FakeStorage()
        self.nodemaker = make_nodemaker(s)

    def test_create(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, MutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.nodemaker.storage_broker
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(None, None, {"k": 3, "n": 10}, None)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
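            # ("large size file" is 15 bytes, so 1000 copies make 15000
            # bytes; with k=3 each share carries roughly 5000 bytes, well
            # over the 2000-byte readsize.)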
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.nodemaker.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_response_cache_memory_leak(self):
        d = self.nodemaker.create_mutable_file("contents")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents"))
            d.addCallback(lambda ign: self.failUnless(isinstance(n._cache, ResponseCache)))

            def _check_cache_size(expected):
                # The total size of cache entries should not increase on the second download.
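                # (len(repr(n._cache.cache)) is a crude proxy for the
                # cache's footprint; it stays the same if no new entries
                # were added.)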
                d2 = n.download_best_version()
                d2.addCallback(lambda ign: self.failUnlessEqual(len(repr(n._cache.cache)), expected))
                return d2
            d.addCallback(lambda ign: _check_cache_size(len(repr(n._cache.cache))))
            return d
        d.addCallback(_created)
        return d
    test_response_cache_memory_leak.todo = "This isn't fixed (see #1045)."

    def test_create_with_initial_contents_function(self):
        data = "initial contents"
        def _make_contents(n):
            self.failUnless(isinstance(n, MutableFileNode))
            key = n.get_writekey()
            self.failUnless(isinstance(key, str), key)
            self.failUnlessEqual(len(key), 16) # AES key size
            return data
        d = self.nodemaker.create_mutable_file(_make_contents)
        def _created(n):
            return n.download_best_version()
        d.addCallback(_created)
        d.addCallback(lambda data2: self.failUnlessEqual(data2, data))
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.nodemaker.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))


            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.nodemaker.key_generator = client.KeyGenerator()
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            p = Publish(fn, nm.storage_broker, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            self._fn = fn
            p = Publish(fn, nm.storage_broker, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10)) == 4
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, self._fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
        d.addCallback(_created)
        return d

    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        d = self._nodemaker.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
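        # For example, versionmap={0:2, 1:2} rolls shares 0 and 1 back to
        # the copies recorded by _copy_shares(..., 2), on every peer.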
        shares = self._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None):
        if fn is None:
            fn = self._fn
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(fn, sb, Monitor(),
                               ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                               oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._nodemaker.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._nodemaker.create_from_cap(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
        if oldmap is None:
            oldmap = ServerMap()
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(self._fn, sb, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        sb2 = make_storagebroker(num_peers=0)
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap(sb=sb2)
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        sb2 = make_storagebroker(num_peers=0)
        self._fn._storage_broker = sb2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._storage_broker = self._storage_broker
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnlessIn(substring, "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d
        # when the version byte is not 0, we hit an UnknownVersionError
        # in unpack_share().
        d = self._test_corrupt_all(0, "UnknownVersionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
1155         # they're good but retrieve marks them as bad), so that the
1156         # MODE_READ set of 6 will be insufficient, forcing node.download to
1157         # retry with more servers.
1158         corrupt(None, self._storage, "share_data", range(5))
1159         d = self.make_servermap()
1160         def _do_retrieve(servermap):
1161             ver = servermap.best_recoverable_version()
1162             self.failUnless(ver)
1163             return self._fn.download_best_version()
1164         d.addCallback(_do_retrieve)
1165         d.addCallback(lambda new_contents:
1166                       self.failUnlessEqual(new_contents, self.CONTENTS))
1167         return d
1168
1169     def test_download_fails(self):
1170         corrupt(None, self._storage, "signature")
1171         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1172                             "no recoverable versions",
1173                             self._fn.download_best_version)
1174         return d
1175
1176
1177 class CheckerMixin:
1178     def check_good(self, r, where):
1179         self.failUnless(r.is_healthy(), where)
1180         return r
1181
1182     def check_bad(self, r, where):
1183         self.failIf(r.is_healthy(), where)
1184         return r
1185
1186     def check_expected_failure(self, r, expected_exception, substring, where):
1187         for (peerid, storage_index, shnum, f) in r.problems:
1188             if f.check(expected_exception):
1189                 self.failUnless(substring in str(f),
1190                                 "%s: substring '%s' not in '%s'" %
1191                                 (where, substring, str(f)))
1192                 return
1193         self.fail("%s: didn't see expected exception %s in problems %s" %
1194                   (where, expected_exception, r.problems))
1195
1196
1197 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1198     def setUp(self):
1199         return self.publish_one()
1200
1201
1202     def test_check_good(self):
1203         d = self._fn.check(Monitor())
1204         d.addCallback(self.check_good, "test_check_good")
1205         return d
1206
1207     def test_check_no_shares(self):
1208         for shares in self._storage._peers.values():
1209             shares.clear()
1210         d = self._fn.check(Monitor())
1211         d.addCallback(self.check_bad, "test_check_no_shares")
1212         return d
1213
1214     def test_check_not_enough_shares(self):
1215         for shares in self._storage._peers.values():
1216             for shnum in list(shares.keys()):
1217                 if shnum > 0:
1218                     del shares[shnum]
1219         d = self._fn.check(Monitor())
1220         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1221         return d
1222
1223     def test_check_all_bad_sig(self):
1224         corrupt(None, self._storage, 1) # bad sig
1225         d = self._fn.check(Monitor())
1226         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1227         return d
1228
1229     def test_check_all_bad_blocks(self):
1230         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1231         # the Checker won't notice this: it doesn't look at the actual data
1232         d = self._fn.check(Monitor())
1233         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1234         return d
1235
1236     def test_verify_good(self):
1237         d = self._fn.check(Monitor(), verify=True)
1238         d.addCallback(self.check_good, "test_verify_good")
1239         return d
1240
1241     def test_verify_all_bad_sig(self):
1242         corrupt(None, self._storage, 1) # bad sig
1243         d = self._fn.check(Monitor(), verify=True)
1244         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1245         return d
1246
1247     def test_verify_one_bad_sig(self):
1248         corrupt(None, self._storage, 1, [9]) # bad sig
1249         d = self._fn.check(Monitor(), verify=True)
1250         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1251         return d
1252
1253     def test_verify_one_bad_block(self):
1254         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1255         # the Verifier *will* notice this, since it examines every byte
1256         d = self._fn.check(Monitor(), verify=True)
1257         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1258         d.addCallback(self.check_expected_failure,
1259                       CorruptShareError, "block hash tree failure",
1260                       "test_verify_one_bad_block")
1261         return d
1262
1263     def test_verify_one_bad_sharehash(self):
1264         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1265         d = self._fn.check(Monitor(), verify=True)
1266         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1267         d.addCallback(self.check_expected_failure,
1268                       CorruptShareError, "corrupt hashes",
1269                       "test_verify_one_bad_sharehash")
1270         return d
1271
1272     def test_verify_one_bad_encprivkey(self):
1273         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1274         d = self._fn.check(Monitor(), verify=True)
1275         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1276         d.addCallback(self.check_expected_failure,
1277                       CorruptShareError, "invalid privkey",
1278                       "test_verify_one_bad_encprivkey")
1279         return d
1280
1281     def test_verify_one_bad_encprivkey_uncheckable(self):
1282         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1283         readonly_fn = self._fn.get_readonly()
1284         # a read-only node has no way to validate the privkey
1285         d = readonly_fn.check(Monitor(), verify=True)
1286         d.addCallback(self.check_good,
1287                       "test_verify_one_bad_encprivkey_uncheckable")
1288         return d
1289
1290 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1291
1292     def get_shares(self, s):
1293         all_shares = {} # maps (peerid, shnum) to share data
1294         for peerid in s._peers:
1295             shares = s._peers[peerid]
1296             for shnum in shares:
1297                 data = shares[shnum]
1298                 all_shares[ (peerid, shnum) ] = data
1299         return all_shares
1300
1301     def copy_shares(self, ignored=None):
1302         self.old_shares.append(self.get_shares(self._storage))
1303
1304     def test_repair_nop(self):
1305         self.old_shares = []
1306         d = self.publish_one()
1307         d.addCallback(self.copy_shares)
1308         d.addCallback(lambda res: self._fn.check(Monitor()))
1309         d.addCallback(lambda check_results: self._fn.repair(check_results))
1310         def _check_results(rres):
1311             self.failUnless(IRepairResults.providedBy(rres))
1312             self.failUnless(rres.get_successful())
1313             # TODO: examine results
1314
1315             self.copy_shares()
1316
1317             initial_shares = self.old_shares[0]
1318             new_shares = self.old_shares[1]
1319             # TODO: this really shouldn't change anything. When we implement
1320             # a "minimal-bandwidth" repairer, change this test to assert:
1321             #self.failUnlessEqual(new_shares, initial_shares)
1322
1323             # all shares should be in the same place as before
1324             self.failUnlessEqual(set(initial_shares.keys()),
1325                                  set(new_shares.keys()))
1326             # but they should all be at a newer seqnum. The IV will be
1327             # different, so the roothash will be too.
1328             for key in initial_shares:
1329                 (version0,
1330                  seqnum0,
1331                  root_hash0,
1332                  IV0,
1333                  k0, N0, segsize0, datalen0,
1334                  o0) = unpack_header(initial_shares[key])
1335                 (version1,
1336                  seqnum1,
1337                  root_hash1,
1338                  IV1,
1339                  k1, N1, segsize1, datalen1,
1340                  o1) = unpack_header(new_shares[key])
1341                 self.failUnlessEqual(version0, version1)
1342                 self.failUnlessEqual(seqnum0+1, seqnum1)
1343                 self.failUnlessEqual(k0, k1)
1344                 self.failUnlessEqual(N0, N1)
1345                 self.failUnlessEqual(segsize0, segsize1)
1346                 self.failUnlessEqual(datalen0, datalen1)
1347         d.addCallback(_check_results)
1348         return d
1349
1350     def failIfSharesChanged(self, ignored=None):
1351         old_shares = self.old_shares[-2]
1352         current_shares = self.old_shares[-1]
1353         self.failUnlessEqual(old_shares, current_shares)
1354
1355     def test_unrepairable_0shares(self):
1356         d = self.publish_one()
1357         def _delete_all_shares(ign):
1358             shares = self._storage._peers
1359             for peerid in shares:
1360                 shares[peerid] = {}
1361         d.addCallback(_delete_all_shares)
1362         d.addCallback(lambda ign: self._fn.check(Monitor()))
1363         d.addCallback(lambda check_results: self._fn.repair(check_results))
1364         def _check(crr):
1365             self.failUnlessEqual(crr.get_successful(), False)
1366         d.addCallback(_check)
1367         return d
1368
1369     def test_unrepairable_1share(self):
1370         d = self.publish_one()
1371         def _delete_all_shares(ign):
1372             shares = self._storage._peers
1373             for peerid in shares:
1374                 for shnum in list(shares[peerid]):
1375                     if shnum > 0:
1376                         del shares[peerid][shnum]
1377         d.addCallback(_delete_all_shares)
1378         d.addCallback(lambda ign: self._fn.check(Monitor()))
1379         d.addCallback(lambda check_results: self._fn.repair(check_results))
1380         def _check(crr):
1381             self.failUnlessEqual(crr.get_successful(), False)
1382         d.addCallback(_check)
1383         return d
1384
1385     def test_merge(self):
1386         self.old_shares = []
1387         d = self.publish_multiple()
1388         # repair will refuse to merge multiple highest seqnums unless you
1389         # pass force=True
1390         d.addCallback(lambda res:
1391                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1392                                           1:4,3:4,5:4,7:4,9:4}))
1393         d.addCallback(self.copy_shares)
1394         d.addCallback(lambda res: self._fn.check(Monitor()))
1395         def _try_repair(check_results):
1396             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1397             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1398                                  self._fn.repair, check_results)
1399             d2.addCallback(self.copy_shares)
1400             d2.addCallback(self.failIfSharesChanged)
1401             d2.addCallback(lambda res: check_results)
1402             return d2
1403         d.addCallback(_try_repair)
1404         d.addCallback(lambda check_results:
1405                       self._fn.repair(check_results, force=True))
1406         # this should give us 10 shares of the highest roothash
1407         def _check_repair_results(rres):
1408             self.failUnless(rres.get_successful())
1409             pass # TODO
1410         d.addCallback(_check_repair_results)
1411         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1412         def _check_smap(smap):
1413             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1414             self.failIf(smap.unrecoverable_versions())
1415             # now, which should have won?
1416             roothash_s4a = self.get_roothash_for(3)
1417             roothash_s4b = self.get_roothash_for(4)
1418             if roothash_s4b > roothash_s4a:
1419                 expected_contents = self.CONTENTS[4]
1420             else:
1421                 expected_contents = self.CONTENTS[3]
1422             new_versionid = smap.best_recoverable_version()
1423             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1424             d2 = self._fn.download_version(smap, new_versionid)
1425             d2.addCallback(self.failUnlessEqual, expected_contents)
1426             return d2
1427         d.addCallback(_check_smap)
1428         return d
1429
1430     def test_non_merge(self):
1431         self.old_shares = []
1432         d = self.publish_multiple()
1433         # repair should not refuse a repair that doesn't need to merge. In
1434         # this case, we combine v2 with v3. The repair should ignore v2 and
1435         # republish v3's contents under a new, higher seqnum.
1436         d.addCallback(lambda res:
1437                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1438                                           1:3,3:3,5:3,7:3,9:3}))
1439         d.addCallback(lambda res: self._fn.check(Monitor()))
1440         d.addCallback(lambda check_results: self._fn.repair(check_results))
1441         # this should give us 10 shares of v3
1442         def _check_repair_results(rres):
1443             self.failUnless(rres.get_successful())
1444             pass # TODO
1445         d.addCallback(_check_repair_results)
1446         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1447         def _check_smap(smap):
1448             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1449             self.failIf(smap.unrecoverable_versions())
1450             # now, which should have won?
1451             expected_contents = self.CONTENTS[3]
1452             new_versionid = smap.best_recoverable_version()
1453             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1454             d2 = self._fn.download_version(smap, new_versionid)
1455             d2.addCallback(self.failUnlessEqual, expected_contents)
1456             return d2
1457         d.addCallback(_check_smap)
1458         return d
1459
1460     def get_roothash_for(self, index):
1461         # return the roothash for the first share we see in the saved set
1462         shares = self._copied_shares[index]
1463         for peerid in shares:
1464             for shnum in shares[peerid]:
1465                 share = shares[peerid][shnum]
1466                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1467                           unpack_header(share)
1468                 return root_hash
1469
1470     def test_check_and_repair_readcap(self):
1471         # we can't currently repair from a mutable readcap: #625
1472         self.old_shares = []
1473         d = self.publish_one()
1474         d.addCallback(self.copy_shares)
1475         def _get_readcap(res):
1476             self._fn3 = self._fn.get_readonly()
1477             # also delete some shares
1478             for peerid,shares in self._storage._peers.items():
1479                 shares.pop(0, None)
1480         d.addCallback(_get_readcap)
1481         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1482         def _check_results(crr):
1483             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1484             # we should detect the unhealthy, but skip over mutable-readcap
1485             # repairs until #625 is fixed
1486             self.failIf(crr.get_pre_repair_results().is_healthy())
1487             self.failIf(crr.get_repair_attempted())
1488             self.failIf(crr.get_post_repair_results().is_healthy())
1489         d.addCallback(_check_results)
1490         return d
1491
1492 class DevNullDictionary(dict):
1493     def __setitem__(self, key, value):
1494         return
1495
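# DevNullDictionary silently discards every store, so any cache backed by it
# always misses; the tests below use it to disable the nodemaker's node
# cache. A standalone sanity sketch (illustrative only: nothing in the suite
# calls this helper):
def _devnull_dictionary_demo():
    d = DevNullDictionary()
    d["key"] = "value"     # the write is silently dropped
    assert "key" not in d  # so every lookup misses
    assert len(d) == 0
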
1496 class MultipleEncodings(unittest.TestCase):
1497     def setUp(self):
1498         self.CONTENTS = "New contents go here"
1499         self._storage = FakeStorage()
1500         self._nodemaker = make_nodemaker(self._storage, num_peers=20)
1501         self._storage_broker = self._nodemaker.storage_broker
1502         d = self._nodemaker.create_mutable_file(self.CONTENTS)
1503         def _created(node):
1504             self._fn = node
1505         d.addCallback(_created)
1506         return d
1507
1508     def _encode(self, k, n, data):
1509         # encode 'data' into a peerid->shares dict.
1510
1511         fn = self._fn
1512         # disable the nodecache, since for these tests we explicitly need
1513         # multiple nodes pointing at the same file
1514         self._nodemaker._node_cache = DevNullDictionary()
1515         fn2 = self._nodemaker.create_from_cap(fn.get_uri())
1516         # then we copy over other fields that are normally fetched from the
1517         # existing shares
1518         fn2._pubkey = fn._pubkey
1519         fn2._privkey = fn._privkey
1520         fn2._encprivkey = fn._encprivkey
1521         # and set the encoding parameters to something completely different
1522         fn2._required_shares = k
1523         fn2._total_shares = n
1524
1525         s = self._storage
1526         s._peers = {} # clear existing storage
1527         p2 = Publish(fn2, self._storage_broker, None)
1528         d = p2.publish(data)
1529         def _published(res):
1530             shares = s._peers
1531             s._peers = {}
1532             return shares
1533         d.addCallback(_published)
1534         return d
1535
1536     def make_servermap(self, mode=MODE_READ, oldmap=None):
1537         if oldmap is None:
1538             oldmap = ServerMap()
1539         smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
1540                                oldmap, mode)
1541         d = smu.update()
1542         return d
1543
1544     def test_multiple_encodings(self):
1545         # we encode the same file in two different ways (3-of-10 and 4-of-9),
1546         # then mix up the shares, to make sure that download survives seeing
1547         # a variety of encodings. This is actually kind of tricky to set up.
1548
1549         contents1 = "Contents for encoding 1 (3-of-10) go here"
1550         contents2 = "Contents for encoding 2 (4-of-9) go here"
1551         contents3 = "Contents for encoding 3 (4-of-7) go here"
1552
1553         # we make a file node that doesn't know what encoding
1554         # parameters to use
1555         fn3 = self._nodemaker.create_from_cap(self._fn.get_uri())
1556
1557         # now we upload a file through fn1, and grab its shares
1558         d = self._encode(3, 10, contents1)
1559         def _encoded_1(shares):
1560             self._shares1 = shares
1561         d.addCallback(_encoded_1)
1562         d.addCallback(lambda res: self._encode(4, 9, contents2))
1563         def _encoded_2(shares):
1564             self._shares2 = shares
1565         d.addCallback(_encoded_2)
1566         d.addCallback(lambda res: self._encode(4, 7, contents3))
1567         def _encoded_3(shares):
1568             self._shares3 = shares
1569         d.addCallback(_encoded_3)
1570
1571         def _merge(res):
1572             log.msg("merging sharelists")
1573             # we merge the shares from the three sets, leaving each shnum
1574             # in its original location, but using a share from set1, set2,
1575             # or set3 according to the following sequence:
1576             #
1577             #  4-of-9  a  s2
1578             #  4-of-9  b  s2
1579             #  4-of-7  c   s3
1580             #  4-of-9  d  s2
1581             #  3-of-10 e s1
1582             #  3-of-10 f s1
1583             #  3-of-10 g s1
1584             #  4-of-9  h  s2
1585             #
1586             # so that neither form can be recovered until fetch [f], at which
1587             # point version-s1 (the 3-of-10 form) should be recoverable. If
1588             # the implementation latches on to the first version it sees,
1589             # then s2 will be recoverable at fetch [g].
1590
1591             # Later, when we implement code that handles multiple versions,
1592             # we can use this framework to assert that all recoverable
1593             # versions are retrieved, and test that 'epsilon' does its job
1594
1595             places = [2, 2, 3, 2, 1, 1, 1, 2]
1596
1597             sharemap = {}
1598             sb = self._storage_broker
1599
1600             for peerid in sorted(sb.get_all_serverids()):
1601                 for shnum in self._shares1.get(peerid, {}):
1602                     if shnum < len(places):
1603                         which = places[shnum]
1604                     else:
1605                         which = "x"
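                    # note: this reassignment wipes any shares already placed
                    # for this peerid; that appears harmless here because,
                    # with 20 servers and at most 10 shares per encoding,
                    # each peer holds at most one shnum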
1606                     self._storage._peers[peerid] = peers = {}
1607                     in_1 = shnum in self._shares1[peerid]
1608                     in_2 = shnum in self._shares2.get(peerid, {})
1609                     in_3 = shnum in self._shares3.get(peerid, {})
1610                     if which == 1:
1611                         if in_1:
1612                             peers[shnum] = self._shares1[peerid][shnum]
1613                             sharemap[shnum] = peerid
1614                     elif which == 2:
1615                         if in_2:
1616                             peers[shnum] = self._shares2[peerid][shnum]
1617                             sharemap[shnum] = peerid
1618                     elif which == 3:
1619                         if in_3:
1620                             peers[shnum] = self._shares3[peerid][shnum]
1621                             sharemap[shnum] = peerid
1622
1623             # we don't bother placing any other shares
1624             # now sort the sequence so that share 0 is returned first
1625             new_sequence = [sharemap[shnum]
1626                             for shnum in sorted(sharemap.keys())]
1627             self._storage._sequence = new_sequence
1628             log.msg("merge done")
1629         d.addCallback(_merge)
1630         d.addCallback(lambda res: fn3.download_best_version())
1631         def _retrieved(new_contents):
1632             # the current specified behavior is "first version recoverable"
1633             self.failUnlessEqual(new_contents, contents1)
1634         d.addCallback(_retrieved)
1635         return d
1636
1637
1638 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1639
1640     def setUp(self):
1641         return self.publish_multiple()
1642
1643     def test_multiple_versions(self):
1644         # if we see a mix of versions in the grid, download_best_version
1645         # should get the latest one
1646         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1647         d = self._fn.download_best_version()
1648         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1649         # and the checker should report problems
1650         d.addCallback(lambda res: self._fn.check(Monitor()))
1651         d.addCallback(self.check_bad, "test_multiple_versions")
1652
1653         # but if everything is at version 2, that's what we should download
1654         d.addCallback(lambda res:
1655                       self._set_versions(dict([(i,2) for i in range(10)])))
1656         d.addCallback(lambda res: self._fn.download_best_version())
1657         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1658         # if exactly one share is at version 3, we should still get v2
1659         d.addCallback(lambda res:
1660                       self._set_versions({0:3}))
1661         d.addCallback(lambda res: self._fn.download_best_version())
1662         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1663         # but the servermap should see the unrecoverable version. This
1664         # depends upon the single newer share being queried early.
1665         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1666         def _check_smap(smap):
1667             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1668             newer = smap.unrecoverable_newer_versions()
1669             self.failUnlessEqual(len(newer), 1)
1670             verinfo, health = newer.items()[0]
1671             self.failUnlessEqual(verinfo[0], 4)
1672             self.failUnlessEqual(health, (1,3))
1673             self.failIf(smap.needs_merge())
1674         d.addCallback(_check_smap)
1675         # if we have a mix of two parallel versions (s4a and s4b), we could
1676         # recover either
1677         d.addCallback(lambda res:
1678                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1679                                           1:4,3:4,5:4,7:4,9:4}))
1680         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1681         def _check_smap_mixed(smap):
1682             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1683             newer = smap.unrecoverable_newer_versions()
1684             self.failUnlessEqual(len(newer), 0)
1685             self.failUnless(smap.needs_merge())
1686         d.addCallback(_check_smap_mixed)
1687         d.addCallback(lambda res: self._fn.download_best_version())
1688         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1689                                                   res == self.CONTENTS[4]))
1690         return d
1691
1692     def test_replace(self):
1693         # if we see a mix of versions in the grid, we should be able to
1694         # replace them all with a newer version
1695
1696         # if exactly one share is at version 3, we should download (and
1697         # replace) v2, and the new version should supersede both old ones.
1698         # Note: the index given to _set_versions() differs from the seqnum.
1699         target = dict([(i,2) for i in range(10)]) # seqnum3
1700         target[0] = 3 # seqnum4
1701         self._set_versions(target)
1702
1703         def _modify(oldversion, servermap, first_time):
1704             return oldversion + " modified"
1705         d = self._fn.modify(_modify)
1706         d.addCallback(lambda res: self._fn.download_best_version())
1707         expected = self.CONTENTS[2] + " modified"
1708         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1709         # and the servermap should indicate that the outlier was replaced too
1710         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1711         def _check_smap(smap):
1712             self.failUnlessEqual(smap.highest_seqnum(), 5)
1713             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1714             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1715         d.addCallback(_check_smap)
1716         return d
1717
1718
1719 class Utils(unittest.TestCase):
1720     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1721         # we compare this against sets of integers
1722         x = set(range(x_start, x_start+x_length))
1723         y = set(range(y_start, y_start+y_length))
1724         should_be_inside = x.issubset(y)
1725         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1726                                                          y_start, y_length),
1727                              str((x_start, x_length, y_start, y_length)))
1728
1729     def test_cache_inside(self):
1730         c = ResponseCache()
1731         x_start = 10
1732         x_length = 5
1733         for y_start in range(8, 17):
1734             for y_length in range(8):
1735                 self._do_inside(c, x_start, x_length, y_start, y_length)
1736
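    # For reference, the set-based oracle above reduces to simple interval
    # arithmetic. A minimal sketch of the containment relation being tested
    # (a hypothetical helper, not the ResponseCache implementation; it
    # assumes a non-empty x range, as the loops above guarantee):
    def _inside_sketch(self, x_start, x_length, y_start, y_length):
        # non-empty [x_start, x_start+x_length) must lie entirely within
        # [y_start, y_start+y_length)
        return (x_start >= y_start and
                x_start + x_length <= y_start + y_length)
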
1737     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1738         # we compare this against sets of integers
1739         x = set(range(x_start, x_start+x_length))
1740         y = set(range(y_start, y_start+y_length))
1741         overlap = bool(x.intersection(y))
1742         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1743                                                       y_start, y_length),
1744                              str((x_start, x_length, y_start, y_length)))
1745
1746     def test_cache_overlap(self):
1747         c = ResponseCache()
1748         x_start = 10
1749         x_length = 5
1750         for y_start in range(8, 17):
1751             for y_length in range(8):
1752                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1753
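    # The overlap relation likewise reduces to interval arithmetic: two
    # half-open intervals intersect iff each starts before the other ends,
    # and an empty interval overlaps nothing. Again a hypothetical sketch of
    # the invariant, not the ResponseCache code itself:
    def _overlap_sketch(self, x_start, x_length, y_start, y_length):
        if x_length <= 0 or y_length <= 0:
            return False # empty intervals have an empty intersection
        return (x_start < y_start + y_length and
                y_start < x_start + x_length)
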
1754     def test_cache(self):
1755         c = ResponseCache()
1756         # xdata = base62.b2a(os.urandom(100))[:100]
1757         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1758         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1759         nope = (None, None)
1760         c.add("v1", 1, 0, xdata, "time0")
1761         c.add("v1", 1, 2000, ydata, "time1")
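        # the cache now holds two fragments of verinfo "v1", share 1:
        # [0,100) stamped "time0" and [2000,2100) stamped "time1". A read
        # misses unless it names the same verinfo and shnum and lies wholly
        # inside a single fragment.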
1762         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1763         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1764         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1765         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1766         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1767         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1768         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1769         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1770         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1771         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1772         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1773         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1774         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1775         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1776         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1777         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1778         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1779         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1780
1781         # optional: join adjacent fragments (not implemented, so the check below stays commented out)
1782         c = ResponseCache()
1783         c.add("v1", 1, 0, xdata[:10], "time0")
1784         c.add("v1", 1, 10, xdata[10:20], "time1")
1785         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1786
1787 class Exceptions(unittest.TestCase):
1788     def test_repr(self):
1789         nmde = NeedMoreDataError(100, 50, 100)
1790         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1791         ucwe = UncoordinatedWriteError()
1792         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1793
1794 class SameKeyGenerator:
1795     def __init__(self, pubkey, privkey):
1796         self.pubkey = pubkey
1797         self.privkey = privkey
1798     def generate(self, keysize=None):
1799         return defer.succeed( (self.pubkey, self.privkey) )
1800
1801 class FirstServerGetsKilled:
1802     done = False
1803     def notify(self, retval, wrapper, methname):
1804         if not self.done:
1805             wrapper.broken = True
1806             self.done = True
1807         return retval
1808
1809 class FirstServerGetsDeleted:
1810     def __init__(self):
1811         self.done = False
1812         self.silenced = None
1813     def notify(self, retval, wrapper, methname):
1814         if not self.done:
1815             # this query will work, but later queries should think the share
1816             # has been deleted
1817             self.done = True
1818             self.silenced = wrapper
1819             return retval
1820         if wrapper == self.silenced:
1821             assert methname == "slot_testv_and_readv_and_writev"
1822             return (True, {})
1823         return retval
1824
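# Both fault-injection helpers plug into the no_network grid's
# post_call_notifier hook: each remote call on a wrapped server is passed
# through notify(retval, wrapper, methname) after it completes, letting the
# first server to answer be broken (or have its writes silently swallowed)
# mid-test. A minimal wiring sketch, matching what the Problems tests below
# do inline:
def _install_notifier_sketch(nm, notifier):
    # illustrative helper; the tests install the hook directly
    for (serverid, ss) in nm.storage_broker.get_all_servers():
        ss.post_call_notifier = notifier.notify
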
1825 class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
1826     def test_publish_surprise(self):
1827         self.basedir = "mutable/Problems/test_publish_surprise"
1828         self.set_up_grid()
1829         nm = self.g.clients[0].nodemaker
1830         d = nm.create_mutable_file("contents 1")
1831         def _created(n):
1832             d = defer.succeed(None)
1833             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1834             def _got_smap1(smap):
1835                 # stash the old state of the file
1836                 self.old_map = smap
1837             d.addCallback(_got_smap1)
1838             # then modify the file, leaving the old map untouched
1839             d.addCallback(lambda res: log.msg("starting winning write"))
1840             d.addCallback(lambda res: n.overwrite("contents 2"))
1841             # now attempt to modify the file with the old servermap. This
1842             # will look just like an uncoordinated write, in which every
1843             # single share got updated between our mapupdate and our publish
1844             d.addCallback(lambda res: log.msg("starting doomed write"))
1845             d.addCallback(lambda res:
1846                           self.shouldFail(UncoordinatedWriteError,
1847                                           "test_publish_surprise", None,
1848                                           n.upload,
1849                                           "contents 2a", self.old_map))
1850             return d
1851         d.addCallback(_created)
1852         return d
1853
1854     def test_retrieve_surprise(self):
1855         self.basedir = "mutable/Problems/test_retrieve_surprise"
1856         self.set_up_grid()
1857         nm = self.g.clients[0].nodemaker
1858         d = nm.create_mutable_file("contents 1")
1859         def _created(n):
1860             d = defer.succeed(None)
1861             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1862             def _got_smap1(smap):
1863                 # stash the old state of the file
1864                 self.old_map = smap
1865             d.addCallback(_got_smap1)
1866             # then modify the file, leaving the old map untouched
1867             d.addCallback(lambda res: log.msg("starting winning write"))
1868             d.addCallback(lambda res: n.overwrite("contents 2"))
1869             # now attempt to retrieve the old version with the old servermap.
1870             # This will look like someone has changed the file since we
1871             # updated the servermap.
1872             d.addCallback(lambda res: n._cache._clear())
1873             d.addCallback(lambda res: log.msg("starting doomed read"))
1874             d.addCallback(lambda res:
1875                           self.shouldFail(NotEnoughSharesError,
1876                                           "test_retrieve_surprise",
1877                                           "ran out of peers: have 0 shares (k=3)",
1878                                           n.download_version,
1879                                           self.old_map,
1880                                           self.old_map.best_recoverable_version(),
1881                                           ))
1882             return d
1883         d.addCallback(_created)
1884         return d
1885
1886     def test_unexpected_shares(self):
1887         # upload the file, take a servermap, shut down one of the servers,
1888         # upload it again (causing shares to appear on a new server), then
1889         # upload using the old servermap. The last upload should fail with an
1890         # UncoordinatedWriteError, because of the shares that didn't appear
1891         # in the servermap.
1892         self.basedir = "mutable/Problems/test_unexpected_shares"
1893         self.set_up_grid()
1894         nm = self.g.clients[0].nodemaker
1895         d = nm.create_mutable_file("contents 1")
1896         def _created(n):
1897             d = defer.succeed(None)
1898             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1899             def _got_smap1(smap):
1900                 # stash the old state of the file
1901                 self.old_map = smap
1902                 # now shut down one of the servers
1903                 peer0 = list(smap.make_sharemap()[0])[0]
1904                 self.g.remove_server(peer0)
1905                 # then modify the file, leaving the old map untouched
1906                 log.msg("starting winning write")
1907                 return n.overwrite("contents 2")
1908             d.addCallback(_got_smap1)
1909             # now attempt to modify the file with the old servermap. This
1910             # will look just like an uncoordinated write, in which every
1911             # single share got updated between our mapupdate and our publish
1912             d.addCallback(lambda res: log.msg("starting doomed write"))
1913             d.addCallback(lambda res:
1914                           self.shouldFail(UncoordinatedWriteError,
1915                                           "test_unexpected_shares", None,
1916                                           n.upload,
1917                                           "contents 2a", self.old_map))
1918             return d
1919         d.addCallback(_created)
1920         return d
1921
1922     def test_bad_server(self):
1923         # Break one server, then create the file: the initial publish should
1924         # complete with an alternate server. Breaking a second server should
1925         # not prevent an update from succeeding either.
1926         self.basedir = "mutable/Problems/test_bad_server"
1927         self.set_up_grid()
1928         nm = self.g.clients[0].nodemaker
1929
1930         # to make sure that one of the initial peers is broken, we have to
1931         # get creative. We create an RSA key and compute its storage-index.
1932         # Then we make a KeyGenerator that always returns that one key, and
1933         # use it to create the mutable file. This will get easier when we can
1934         # use #467 static-server-selection to disable permutation and force
1935         # the choice of server for share[0].
1936
1937         d = nm.key_generator.generate(522)
1938         def _got_key( (pubkey, privkey) ):
1939             nm.key_generator = SameKeyGenerator(pubkey, privkey)
1940             pubkey_s = pubkey.serialize()
1941             privkey_s = privkey.serialize()
1942             u = uri.WriteableSSKFileURI(ssk_writekey_hash(privkey_s),
1943                                         ssk_pubkey_fingerprint_hash(pubkey_s))
1944             self._storage_index = u.get_storage_index()
1945         d.addCallback(_got_key)
1946         def _break_peer0(res):
1947             si = self._storage_index
1948             peerlist = nm.storage_broker.get_servers_for_index(si)
1949             peerid0, connection0 = peerlist[0]
1950             peerid1, connection1 = peerlist[1]
1951             connection0.broken = True
1952             self.connection1 = connection1
1953         d.addCallback(_break_peer0)
1954         # now "create" the file, using the pre-established key, and let the
1955         # initial publish finally happen
1956         d.addCallback(lambda res: nm.create_mutable_file("contents 1"))
1957         # that ought to work
1958         def _got_node(n):
1959             d = n.download_best_version()
1960             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1961             # now break the second peer
1962             def _break_peer1(res):
1963                 self.connection1.broken = True
1964             d.addCallback(_break_peer1)
1965             d.addCallback(lambda res: n.overwrite("contents 2"))
1966             # that ought to work too
1967             d.addCallback(lambda res: n.download_best_version())
1968             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1969             def _explain_error(f):
1970                 print f
1971                 if f.check(NotEnoughServersError):
1972                     print "first_error:", f.value.first_error
1973                 return f
1974             d.addErrback(_explain_error)
1975             return d
1976         d.addCallback(_got_node)
1977         return d
1978
1979     def test_bad_server_overlap(self):
1980         # like test_bad_server, but with no extra unused servers to fall back
1981         # upon. This means that we must re-use a server which we've already
1982         # used. If we don't remember the fact that we sent them one share
1983         # already, we'll mistakenly think we're experiencing an
1984         # UncoordinatedWriteError.
1985
1986         # Break one server, then create the file: the initial publish should
1987         # complete with an alternate server. Breaking a second server should
1988         # not prevent an update from succeeding either.
1989         self.basedir = "mutable/Problems/test_bad_server_overlap"
1990         self.set_up_grid()
1991         nm = self.g.clients[0].nodemaker
1992         sb = nm.storage_broker
1993
1994         peerids = [serverid for (serverid,ss) in sb.get_all_servers()]
1995         self.g.break_server(peerids[0])
1996
1997         d = nm.create_mutable_file("contents 1")
1998         def _created(n):
1999             d = n.download_best_version()
2000             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
2001             # now break one of the remaining servers
2002             def _break_second_server(res):
2003                 self.g.break_server(peerids[1])
2004             d.addCallback(_break_second_server)
2005             d.addCallback(lambda res: n.overwrite("contents 2"))
2006             # that ought to work too
2007             d.addCallback(lambda res: n.download_best_version())
2008             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
2009             return d
2010         d.addCallback(_created)
2011         return d
2012
2013     def test_publish_all_servers_bad(self):
2014         # Break all servers: the publish should fail
2015         self.basedir = "mutable/Problems/test_publish_all_servers_bad"
2016         self.set_up_grid()
2017         nm = self.g.clients[0].nodemaker
2018         for (serverid,ss) in nm.storage_broker.get_all_servers():
2019             ss.broken = True
2020
2021         d = self.shouldFail(NotEnoughServersError,
2022                             "test_publish_all_servers_bad",
2023                             "Ran out of non-bad servers",
2024                             nm.create_mutable_file, "contents")
2025         return d
2026
2027     def test_publish_no_servers(self):
2028         # no servers at all: the publish should fail
2029         self.basedir = "mutable/Problems/test_publish_no_servers"
2030         self.set_up_grid(num_servers=0)
2031         nm = self.g.clients[0].nodemaker
2032
2033         d = self.shouldFail(NotEnoughServersError,
2034                             "test_publish_no_servers",
2035                             "Ran out of non-bad servers",
2036                             nm.create_mutable_file, "contents")
2037         return d
2038     test_publish_no_servers.timeout = 30
2039
2040
2041     def test_privkey_query_error(self):
2042         # when a servermap is updated with MODE_WRITE, it tries to get the
2043         # privkey. Something might go wrong during this query attempt.
2044         # Exercise the code in _privkey_query_failed which tries to handle
2045         # such an error.
2046         self.basedir = "mutable/Problems/test_privkey_query_error"
2047         self.set_up_grid(num_servers=20)
2048         nm = self.g.clients[0].nodemaker
2049         nm._node_cache = DevNullDictionary() # disable the nodecache
2050
2051         # we need some contents that are large enough to push the privkey out
2052         # of the early part of the file
2053         LARGE = "These are Larger contents" * 2000 # about 50KB
2054         d = nm.create_mutable_file(LARGE)
2055         def _created(n):
2056             self.uri = n.get_uri()
2057             self.n2 = nm.create_from_cap(self.uri)
2058
2059             # When a mapupdate is performed on a node that doesn't yet know
2060             # the privkey, a short read is sent to a batch of servers, to get
2061             # the verinfo and (hopefully, if the file is short enough) the
2062             # encprivkey. Our file is too large to let this first read
2063             # contain the encprivkey. Each non-encprivkey-bearing response
2064             # that arrives (until the node gets the encprivkey) will trigger
2065             # a second read to specifically read the encprivkey.
2066             #
2067             # So, to exercise this case:
2068             #  1. notice which server gets a read() call first
2069             #  2. tell that server to start throwing errors
2070             killer = FirstServerGetsKilled()
2071             for (serverid,ss) in nm.storage_broker.get_all_servers():
2072                 ss.post_call_notifier = killer.notify
2073         d.addCallback(_created)
2074
2075         # now we update a servermap from a new node (which doesn't have the
2076         # privkey yet, forcing it to use a separate privkey query). Note that
2077         # the map-update will succeed, since we'll just get a copy from one
2078         # of the other shares.
2079         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2080
2081         return d
2082
2083     def test_privkey_query_missing(self):
2084         # like test_privkey_query_error, but the shares are deleted by the
2085         # second query, instead of raising an exception.
2086         self.basedir = "mutable/Problems/test_privkey_query_missing"
2087         self.set_up_grid(num_servers=20)
2088         nm = self.g.clients[0].nodemaker
2089         LARGE = "These are Larger contents" * 2000 # about 50KB
2090         nm._node_cache = DevNullDictionary() # disable the nodecache
2091
2092         d = nm.create_mutable_file(LARGE)
2093         def _created(n):
2094             self.uri = n.get_uri()
2095             self.n2 = nm.create_from_cap(self.uri)
2096             deleter = FirstServerGetsDeleted()
2097             for (serverid,ss) in nm.storage_broker.get_all_servers():
2098                 ss.post_call_notifier = deleter.notify
2099         d.addCallback(_created)
2100         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2101         return d