import struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from allmydata import uri, client
from allmydata.nodemaker import NodeMaker
from allmydata.util import base32
from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
     ssk_pubkey_fingerprint_hash
from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
     NotEnoughSharesError
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from allmydata.test.no_network import GridTestMixin
from foolscap.api import eventually, fireEventually
from foolscap.logging import log
from allmydata.storage_client import StorageFarmBroker

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import allmydata.test.common_util as testutil

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

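# A minimal sketch of how a test can use the ordering control above:
# assigning a list of peerids to _sequence makes read() hold its answers and
# release them in that order (any peers not listed fire afterwards). This
# helper is illustrative only; a test can just as easily set _sequence
# directly.
def _example_force_answer_order(storage, ordered_peerids):
    # held read() queries will be fired (via eventually) in this order, one
    # second after the first query arrives
    storage._sequence = ordered_peerids
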

class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: swallow the result (and any errors) and return
        # None, like foolscap's real callRemoteOnly
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

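# worked example: flip_bit("\x00\xff", 1) XORs the second byte with 0x01,
# turning "\x00\xff" into "\x00\xfe". Only one bit changes, which is enough
# to invalidate any hash or signature covering that byte.
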
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

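# Typical uses, as seen in the tests below: corrupt(None, storage, 1) flips a
# bit in the encoded sequence number (which invalidates the signature), and
# corrupt(None, storage, "share_data", range(5)) damages only shares 0-4.
# Because it passes `res` through, the helper also works mid-Deferred-chain:
# d.addCallback(corrupt, self._storage, offset).
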
def make_storagebroker(s=None, num_peers=10):
    if not s:
        s = FakeStorage()
    peerids = [tagged_hash("peerid", "%d" % i)[:20]
               for i in range(num_peers)]
    storage_broker = StorageFarmBroker(None, True)
    for peerid in peerids:
        fss = FakeStorageServer(peerid, s)
        storage_broker.test_add_server(peerid, fss)
    return storage_broker

def make_nodemaker(s=None, num_peers=10):
    storage_broker = make_storagebroker(s, num_peers)
    sh = client.SecretHolder("lease secret", "convergence secret")
    keygen = client.KeyGenerator()
    keygen.set_default_keysize(522)
    nodemaker = NodeMaker(storage_broker, sh, None,
                          None, None,
                          {"k": 3, "n": 10}, keygen)
    return nodemaker

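# Minimal sketch tying the helpers together; illustrative only (the Filenode
# tests below exercise this same pattern with assertions). Like the tests, it
# assumes it runs under trial so the returned Deferred gets waited on.
def _example_roundtrip(contents="initial contents"):
    nm = make_nodemaker(FakeStorage())
    d = nm.create_mutable_file(contents)
    # download what we just published; the Deferred fires with `contents`
    d.addCallback(lambda n: n.download_best_version())
    return d
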
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # this constant used to live in Publish as a hard limit, but we removed
    # the limit. Some of these tests check that the new code correctly
    # allows files larger than the old limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self._storage = s = FakeStorage()
        self.nodemaker = make_nodemaker(s)

    def test_create(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, MutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.nodemaker.storage_broker
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(None, None, {"k": 3, "n": 10}, None)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.nodemaker.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_response_cache_memory_leak(self):
        d = self.nodemaker.create_mutable_file("contents")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents"))
            d.addCallback(lambda ign: self.failUnless(isinstance(n._cache, ResponseCache)))

            def _check_cache(expected):
                # The total size of cache entries should not increase on the second download;
                # in fact the cache contents should be identical.
                d2 = n.download_best_version()
                d2.addCallback(lambda rep: self.failUnlessEqual(repr(n._cache.cache), expected))
                return d2
            d.addCallback(lambda ign: _check_cache(repr(n._cache.cache)))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents_function(self):
        data = "initial contents"
        def _make_contents(n):
            self.failUnless(isinstance(n, MutableFileNode))
            key = n.get_writekey()
            self.failUnless(isinstance(key, str), key)
            self.failUnlessEqual(len(key), 16) # AES key size
            return data
        d = self.nodemaker.create_mutable_file(_make_contents)
        def _created(n):
            return n.download_best_version()
        d.addCallback(_created)
        d.addCallback(lambda data2: self.failUnlessEqual(data2, data))
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.nodemaker.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            # the too-big failure check that used to live here went away with
            # the size limit (the oversized write itself now happens, and is
            # expected to succeed, at the end of this test); these callbacks
            # just re-verify the state from the previous step
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))
            # the size limit is gone, so this oversized write should succeed
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.nodemaker.key_generator = client.KeyGenerator()
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            p = Publish(fn, nm.storage_broker, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            self._fn = fn
            p = Publish(fn, nm.storage_broker, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10)) == 4
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, self._fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

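# A sketch of how the zero-peers case from the TODO above might be written
# with the helpers in this module (hypothetical, not yet part of the suite;
# the expected exception follows the TODO and has not been verified):
def _example_publish_zero_peers(testcase):
    nm = make_nodemaker(num_peers=0)
    # with no connected servers, publishing the initial version should fail;
    # shouldFail is provided by ShouldFailMixin on the testcase
    return testcase.shouldFail(NotEnoughSharesError,
                               "_example_publish_zero_peers", None,
                               nm.create_mutable_file, "contents")
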
class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
        d.addCallback(_created)
        return d

    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        d = self._nodemaker.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]

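# Illustrative sketch of the mix-and-match pattern these helpers support:
# publish five versions, then send a subset of shares back in time so the
# grid holds a mixture (here shnums 0-2 at version 2, the rest left at
# version 4). Hypothetical helper, not a test by itself:
def _example_mixed_versions(publishmixin):
    d = publishmixin.publish_multiple()
    d.addCallback(lambda ign: publishmixin._set_versions({0: 2, 1: 2, 2: 2}))
    return d
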

class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None):
        if fn is None:
            fn = self._fn
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(fn, sb, Monitor(),
                               ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                               oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._nodemaker.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._nodemaker.create_from_cap(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
        if oldmap is None:
            oldmap = ServerMap()
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(self._fn, sb, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        sb2 = make_storagebroker(num_peers=0)
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap(sb=sb2)
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        sb2 = make_storagebroker(num_peers=0)
        self._fn._storage_broker = sb2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._storage_broker = self._storage_broker
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnlessIn(substring, "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an UnknownVersionError in
        # unpack_share().
        d = self._test_corrupt_all(0, "UnknownVersionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_fails",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d

1177 class CheckerMixin:
1178     def check_good(self, r, where):
1179         self.failUnless(r.is_healthy(), where)
1180         return r
1181
1182     def check_bad(self, r, where):
1183         self.failIf(r.is_healthy(), where)
1184         return r
1185
1186     def check_expected_failure(self, r, expected_exception, substring, where):
1187         for (peerid, storage_index, shnum, f) in r.problems:
1188             if f.check(expected_exception):
1189                 self.failUnless(substring in str(f),
1190                                 "%s: substring '%s' not in '%s'" %
1191                                 (where, substring, str(f)))
1192                 return
1193         self.fail("%s: didn't see expected exception %s in problems %s" %
1194                   (where, expected_exception, r.problems))
1195
1196
1197 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1198     def setUp(self):
1199         return self.publish_one()
1200
1201
1202     def test_check_good(self):
1203         d = self._fn.check(Monitor())
1204         d.addCallback(self.check_good, "test_check_good")
1205         return d
1206
1207     def test_check_no_shares(self):
1208         for shares in self._storage._peers.values():
1209             shares.clear()
1210         d = self._fn.check(Monitor())
1211         d.addCallback(self.check_bad, "test_check_no_shares")
1212         return d
1213
1214     def test_check_not_enough_shares(self):
1215         for shares in self._storage._peers.values():
1216             for shnum in shares.keys():
1217                 if shnum > 0:
1218                     del shares[shnum]
1219         d = self._fn.check(Monitor())
1220         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1221         return d
1222
1223     def test_check_all_bad_sig(self):
1224         corrupt(None, self._storage, 1) # bad sig
1225         d = self._fn.check(Monitor())
1226         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1227         return d
1228
1229     def test_check_all_bad_blocks(self):
1230         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1231         # the Checker won't notice this, since it doesn't look at actual data
1232         d = self._fn.check(Monitor())
1233         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1234         return d
1235
1236     def test_verify_good(self):
1237         d = self._fn.check(Monitor(), verify=True)
1238         d.addCallback(self.check_good, "test_verify_good")
1239         return d
1240
1241     def test_verify_all_bad_sig(self):
1242         corrupt(None, self._storage, 1) # bad sig
1243         d = self._fn.check(Monitor(), verify=True)
1244         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1245         return d
1246
1247     def test_verify_one_bad_sig(self):
1248         corrupt(None, self._storage, 1, [9]) # bad sig
1249         d = self._fn.check(Monitor(), verify=True)
1250         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1251         return d
1252
1253     def test_verify_one_bad_block(self):
1254         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1255         # the Verifier *will* notice this, since it examines every byte
1256         d = self._fn.check(Monitor(), verify=True)
1257         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1258         d.addCallback(self.check_expected_failure,
1259                       CorruptShareError, "block hash tree failure",
1260                       "test_verify_one_bad_block")
1261         return d
1262
1263     def test_verify_one_bad_sharehash(self):
1264         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1265         d = self._fn.check(Monitor(), verify=True)
1266         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1267         d.addCallback(self.check_expected_failure,
1268                       CorruptShareError, "corrupt hashes",
1269                       "test_verify_one_bad_sharehash")
1270         return d
1271
1272     def test_verify_one_bad_encprivkey(self):
1273         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1274         d = self._fn.check(Monitor(), verify=True)
1275         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1276         d.addCallback(self.check_expected_failure,
1277                       CorruptShareError, "invalid privkey",
1278                       "test_verify_one_bad_encprivkey")
1279         return d
1280
1281     def test_verify_one_bad_encprivkey_uncheckable(self):
1282         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1283         readonly_fn = self._fn.get_readonly()
1284         # a read-only node has no way to validate the privkey
1285         d = readonly_fn.check(Monitor(), verify=True)
1286         d.addCallback(self.check_good,
1287                       "test_verify_one_bad_encprivkey_uncheckable")
1288         return d
1289
1290 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1291
1292     def get_shares(self, s):
1293         all_shares = {} # maps (peerid, shnum) to share data
1294         for peerid in s._peers:
1295             shares = s._peers[peerid]
1296             for shnum in shares:
1297                 data = shares[shnum]
1298                 all_shares[ (peerid, shnum) ] = data
1299         return all_shares
1300
1301     def copy_shares(self, ignored=None):
1302         self.old_shares.append(self.get_shares(self._storage))
1303
1304     def test_repair_nop(self):
1305         self.old_shares = []
1306         d = self.publish_one()
1307         d.addCallback(self.copy_shares)
1308         d.addCallback(lambda res: self._fn.check(Monitor()))
1309         d.addCallback(lambda check_results: self._fn.repair(check_results))
1310         def _check_results(rres):
1311             self.failUnless(IRepairResults.providedBy(rres))
1312             self.failUnless(rres.get_successful())
1313             # TODO: examine results
1314
1315             self.copy_shares()
1316
1317             initial_shares = self.old_shares[0]
1318             new_shares = self.old_shares[1]
1319             # TODO: this really shouldn't change anything. When we implement
1320             # a "minimal-bandwidth" repairer, change this test to assert:
1321             #self.failUnlessEqual(new_shares, initial_shares)
1322
1323             # all shares should be in the same place as before
1324             self.failUnlessEqual(set(initial_shares.keys()),
1325                                  set(new_shares.keys()))
1326             # but they should all be at a newer seqnum. The IV will be
1327             # different, so the roothash will be too.
1328             for key in initial_shares:
1329                 (version0,
1330                  seqnum0,
1331                  root_hash0,
1332                  IV0,
1333                  k0, N0, segsize0, datalen0,
1334                  o0) = unpack_header(initial_shares[key])
1335                 (version1,
1336                  seqnum1,
1337                  root_hash1,
1338                  IV1,
1339                  k1, N1, segsize1, datalen1,
1340                  o1) = unpack_header(new_shares[key])
1341                 self.failUnlessEqual(version0, version1)
1342                 self.failUnlessEqual(seqnum0+1, seqnum1)
1343                 self.failUnlessEqual(k0, k1)
1344                 self.failUnlessEqual(N0, N1)
1345                 self.failUnlessEqual(segsize0, segsize1)
1346                 self.failUnlessEqual(datalen0, datalen1)
1347         d.addCallback(_check_results)
1348         return d
1349
1350     def failIfSharesChanged(self, ignored=None):
1351         old_shares = self.old_shares[-2]
1352         current_shares = self.old_shares[-1]
1353         self.failUnlessEqual(old_shares, current_shares)
1354
1355     def test_unrepairable_0shares(self):
1356         d = self.publish_one()
1357         def _delete_all_shares(ign):
1358             shares = self._storage._peers
1359             for peerid in shares:
1360                 shares[peerid] = {}
1361         d.addCallback(_delete_all_shares)
1362         d.addCallback(lambda ign: self._fn.check(Monitor()))
1363         d.addCallback(lambda check_results: self._fn.repair(check_results))
1364         def _check(crr):
1365             self.failUnlessEqual(crr.get_successful(), False)
1366         d.addCallback(_check)
1367         return d
1368
1369     def test_unrepairable_1share(self):
1370         d = self.publish_one()
1371         def _delete_all_but_one_share(ign):
1372             shares = self._storage._peers
1373             for peerid in shares:
1374                 for shnum in list(shares[peerid]):
1375                     if shnum > 0:
1376                         del shares[peerid][shnum]
1377         d.addCallback(_delete_all_but_one_share)
1378         d.addCallback(lambda ign: self._fn.check(Monitor()))
1379         d.addCallback(lambda check_results: self._fn.repair(check_results))
1380         def _check(crr):
1381             self.failUnlessEqual(crr.get_successful(), False)
1382         d.addCallback(_check)
1383         return d
1384
1385     def test_merge(self):
1386         self.old_shares = []
1387         d = self.publish_multiple()
1388         # repair will refuse to merge multiple highest seqnums unless you
1389         # pass force=True
1390         d.addCallback(lambda res:
1391                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1392                                           1:4,3:4,5:4,7:4,9:4}))
1393         d.addCallback(self.copy_shares)
1394         d.addCallback(lambda res: self._fn.check(Monitor()))
1395         def _try_repair(check_results):
1396             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1397             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1398                                  self._fn.repair, check_results)
1399             d2.addCallback(self.copy_shares)
1400             d2.addCallback(self.failIfSharesChanged)
1401             d2.addCallback(lambda res: check_results)
1402             return d2
1403         d.addCallback(_try_repair)
1404         d.addCallback(lambda check_results:
1405                       self._fn.repair(check_results, force=True))
1406         # this should give us 10 shares of the highest roothash
1407         def _check_repair_results(rres):
1408             self.failUnless(rres.get_successful())
1409             # TODO
1410         d.addCallback(_check_repair_results)
1411         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1412         def _check_smap(smap):
1413             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1414             self.failIf(smap.unrecoverable_versions())
1415             # now, which should have won?
1416             roothash_s4a = self.get_roothash_for(3)
1417             roothash_s4b = self.get_roothash_for(4)
1418             if roothash_s4b > roothash_s4a:
1419                 expected_contents = self.CONTENTS[4]
1420             else:
1421                 expected_contents = self.CONTENTS[3]
1422             new_versionid = smap.best_recoverable_version()
1423             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1424             d2 = self._fn.download_version(smap, new_versionid)
1425             d2.addCallback(self.failUnlessEqual, expected_contents)
1426             return d2
1427         d.addCallback(_check_smap)
1428         return d
1429
1430     def test_non_merge(self):
1431         self.old_shares = []
1432         d = self.publish_multiple()
1433         # repair should not refuse a repair that doesn't need to merge. In
1434         # this case we combine the index-2 and index-3 versions; the repair
1435         # should ignore index-2 and republish the index-3 contents at seqnum 5.
1436         d.addCallback(lambda res:
1437                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1438                                           1:3,3:3,5:3,7:3,9:3}))
1439         d.addCallback(lambda res: self._fn.check(Monitor()))
1440         d.addCallback(lambda check_results: self._fn.repair(check_results))
1441         # this should give us 10 shares of v3
1442         def _check_repair_results(rres):
1443             self.failUnless(rres.get_successful())
1444             # TODO
1445         d.addCallback(_check_repair_results)
1446         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1447         def _check_smap(smap):
1448             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1449             self.failIf(smap.unrecoverable_versions())
1450             # now, which should have won?
1451             expected_contents = self.CONTENTS[3]
1452             new_versionid = smap.best_recoverable_version()
1453             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1454             d2 = self._fn.download_version(smap, new_versionid)
1455             d2.addCallback(self.failUnlessEqual, expected_contents)
1456             return d2
1457         d.addCallback(_check_smap)
1458         return d
1459
1460     def get_roothash_for(self, index):
1461         # return the roothash for the first share we see in the saved set
1462         shares = self._copied_shares[index]
1463         for peerid in shares:
1464             for shnum in shares[peerid]:
1465                 share = shares[peerid][shnum]
1466                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1467                           unpack_header(share)
1468                 return root_hash
1469
1470     def test_check_and_repair_readcap(self):
1471         # we can't currently repair from a mutable readcap: #625
1472         self.old_shares = []
1473         d = self.publish_one()
1474         d.addCallback(self.copy_shares)
1475         def _get_readcap(res):
1476             self._fn3 = self._fn.get_readonly()
1477             # also delete some shares
1478             for peerid,shares in self._storage._peers.items():
1479                 shares.pop(0, None)
1480         d.addCallback(_get_readcap)
1481         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1482         def _check_results(crr):
1483             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1484             # we should detect the unhealthy file, but skip over
1485             # mutable-readcap repairs until #625 is fixed
1486             self.failIf(crr.get_pre_repair_results().is_healthy())
1487             self.failIf(crr.get_repair_attempted())
1488             self.failIf(crr.get_post_repair_results().is_healthy())
1489         d.addCallback(_check_results)
1490         return d
1491
1492 class DevNullDictionary(dict):
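    # a dict that silently discards every write: used to disable caches in
    # these tests. A minimal illustration:
    #   d = DevNullDictionary()
    #   d["key"] = "value"
    #   assert "key" not in d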
1493     def __setitem__(self, key, value):
1494         return
1495
1496 class MultipleEncodings(unittest.TestCase):
1497     def setUp(self):
1498         self.CONTENTS = "New contents go here"
1499         self._storage = FakeStorage()
1500         self._nodemaker = make_nodemaker(self._storage, num_peers=20)
1501         self._storage_broker = self._nodemaker.storage_broker
1502         d = self._nodemaker.create_mutable_file(self.CONTENTS)
1503         def _created(node):
1504             self._fn = node
1505         d.addCallback(_created)
1506         return d
1507
1508     def _encode(self, k, n, data):
1509         # encode 'data' into a peerid->shares dict.
1510
1511         fn = self._fn
1512         # disable the nodecache, since for these tests we explicitly need
1513         # multiple nodes pointing at the same file
1514         self._nodemaker._node_cache = DevNullDictionary()
1515         fn2 = self._nodemaker.create_from_cap(fn.get_uri())
1516         # then we copy over other fields that are normally fetched from the
1517         # existing shares
1518         fn2._pubkey = fn._pubkey
1519         fn2._privkey = fn._privkey
1520         fn2._encprivkey = fn._encprivkey
1521         # and set the encoding parameters to something completely different
1522         fn2._required_shares = k
1523         fn2._total_shares = n
1524
1525         s = self._storage
1526         s._peers = {} # clear existing storage
1527         p2 = Publish(fn2, self._storage_broker, None)
1528         d = p2.publish(data)
1529         def _published(res):
1530             shares = s._peers
1531             s._peers = {}
1532             return shares
1533         d.addCallback(_published)
1534         return d
1535
1536     def make_servermap(self, mode=MODE_READ, oldmap=None):
1537         if oldmap is None:
1538             oldmap = ServerMap()
1539         smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
1540                                oldmap, mode)
1541         d = smu.update()
1542         return d
1543
1544     def test_multiple_encodings(self):
1545         # we encode the same file in two different ways (3-of-10 and 4-of-9),
1546         # then mix up the shares, to make sure that download survives seeing
1547         # a variety of encodings. This is actually kind of tricky to set up.
1548
1549         contents1 = "Contents for encoding 1 (3-of-10) go here"
1550         contents2 = "Contents for encoding 2 (4-of-9) go here"
1551         contents3 = "Contents for encoding 3 (4-of-7) go here"
1552
1553         # we make a file node that doesn't yet know which encoding
1554         # parameters to use
1555         fn3 = self._nodemaker.create_from_cap(self._fn.get_uri())
1556
1557         # now we publish the file with the first encoding, and grab its shares
1558         d = self._encode(3, 10, contents1)
1559         def _encoded_1(shares):
1560             self._shares1 = shares
1561         d.addCallback(_encoded_1)
1562         d.addCallback(lambda res: self._encode(4, 9, contents2))
1563         def _encoded_2(shares):
1564             self._shares2 = shares
1565         d.addCallback(_encoded_2)
1566         d.addCallback(lambda res: self._encode(4, 7, contents3))
1567         def _encoded_3(shares):
1568             self._shares3 = shares
1569         d.addCallback(_encoded_3)
1570
1571         def _merge(res):
1572             log.msg("merging sharelists")
1573             # we merge the shares from the three sets, leaving each shnum in
1574             # its original location, but using a share from set1, set2, or
1575             # set3 according to the following sequence:
1576             #
1577             #  4-of-9  a  s2
1578             #  4-of-9  b  s2
1579             #  4-of-7  c   s3
1580             #  4-of-9  d  s2
1581             #  3-of-10 e s1
1582             #  3-of-10 f s1
1583             #  3-of-10 g s1
1584             #  4-of-9  h  s2
1585             #
1586             # so that neither form can be recovered until fetch [g], at which
1587             # point version-s1 (the 3-of-10 form) becomes recoverable. If
1588             # the implementation latches on to the first version it sees,
1589             # then s2 will not be recoverable until fetch [h].
1590
1591             # Later, when we implement code that handles multiple versions,
1592             # we can use this framework to assert that all recoverable
1593             # versions are retrieved, and test that 'epsilon' does its job
1594
1595             places = [2, 2, 3, 2, 1, 1, 1, 2]
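            # i.e. shnum -> which encoding's share to place (1=s1, 2=s2,
            # 3=s3); shnums 8 and 9 fall off the end and get no share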
1596
1597             sharemap = {}
1598             sb = self._storage_broker
1599
1600             for peerid in sorted(sb.get_all_serverids()):
1601                 self._storage._peers[peerid] = peers = {} # reset once per peer, not per shnum
1602                 for shnum in self._shares1.get(peerid, {}):
1603                     if shnum < len(places):
1604                         which = places[shnum]
1605                     else:
1606                         which = "x"
1607                     in_1 = shnum in self._shares1[peerid]
1608                     in_2 = shnum in self._shares2.get(peerid, {})
1609                     in_3 = shnum in self._shares3.get(peerid, {})
1610                     if which == 1:
1611                         if in_1:
1612                             peers[shnum] = self._shares1[peerid][shnum]
1613                             sharemap[shnum] = peerid
1614                     elif which == 2:
1615                         if in_2:
1616                             peers[shnum] = self._shares2[peerid][shnum]
1617                             sharemap[shnum] = peerid
1618                     elif which == 3:
1619                         if in_3:
1620                             peers[shnum] = self._shares3[peerid][shnum]
1621                             sharemap[shnum] = peerid
1622
1623             # we don't bother placing any other shares
1624             # now sort the sequence so that share 0 is returned first
1625             new_sequence = [sharemap[shnum]
1626                             for shnum in sorted(sharemap.keys())]
1627             self._storage._sequence = new_sequence
1628             log.msg("merge done")
1629         d.addCallback(_merge)
1630         d.addCallback(lambda res: fn3.download_best_version())
1631         def _retrieved(new_contents):
1632             # the current specified behavior is "first version recoverable"
1633             self.failUnlessEqual(new_contents, contents1)
1634         d.addCallback(_retrieved)
1635         return d
1636
1637
1638 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1639
1640     def setUp(self):
1641         return self.publish_multiple()
1642
1643     def test_multiple_versions(self):
1644         # if we see a mix of versions in the grid, download_best_version
1645         # should get the latest one
1646         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1647         d = self._fn.download_best_version()
1648         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1649         # and the checker should report problems
1650         d.addCallback(lambda res: self._fn.check(Monitor()))
1651         d.addCallback(self.check_bad, "test_multiple_versions")
1652
1653         # but if everything is at version 2, that's what we should download
1654         d.addCallback(lambda res:
1655                       self._set_versions(dict([(i,2) for i in range(10)])))
1656         d.addCallback(lambda res: self._fn.download_best_version())
1657         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1658         # if exactly one share is at version 3, we should still get v2
1659         d.addCallback(lambda res:
1660                       self._set_versions({0:3}))
1661         d.addCallback(lambda res: self._fn.download_best_version())
1662         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1663         # but the servermap should see the unrecoverable version. This
1664         # depends upon the single newer share being queried early.
1665         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1666         def _check_smap(smap):
1667             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1668             newer = smap.unrecoverable_newer_versions()
1669             self.failUnlessEqual(len(newer), 1)
1670             verinfo, health = newer.items()[0]
1671             self.failUnlessEqual(verinfo[0], 4)
1672             self.failUnlessEqual(health, (1,3))
1673             self.failIf(smap.needs_merge())
1674         d.addCallback(_check_smap)
1675         # if we have a mix of two parallel versions (s4a and s4b), we could
1676         # recover either
1677         d.addCallback(lambda res:
1678                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1679                                           1:4,3:4,5:4,7:4,9:4}))
1680         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1681         def _check_smap_mixed(smap):
1682             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1683             newer = smap.unrecoverable_newer_versions()
1684             self.failUnlessEqual(len(newer), 0)
1685             self.failUnless(smap.needs_merge())
1686         d.addCallback(_check_smap_mixed)
1687         d.addCallback(lambda res: self._fn.download_best_version())
1688         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1689                                                   res == self.CONTENTS[4]))
1690         return d
1691
1692     def test_replace(self):
1693         # if we see a mix of versions in the grid, we should be able to
1694         # replace them all with a newer version
1695
1696         # if exactly one share is at the index-3 version (seqnum 4), we should
1697         # still download (and replace) the index-2 contents (seqnum 3). Note that
1698         # the index we give to _set_versions is one less than the sequence number.
1699         target = dict([(i,2) for i in range(10)]) # seqnum3
1700         target[0] = 3 # seqnum4
1701         self._set_versions(target)
1702
1703         def _modify(oldversion, servermap, first_time):
1704             return oldversion + " modified"
1705         d = self._fn.modify(_modify)
1706         d.addCallback(lambda res: self._fn.download_best_version())
1707         expected = self.CONTENTS[2] + " modified"
1708         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1709         # and the servermap should indicate that the outlier was replaced too
1710         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1711         def _check_smap(smap):
1712             self.failUnlessEqual(smap.highest_seqnum(), 5)
1713             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1714             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1715         d.addCallback(_check_smap)
1716         return d
1717
1718
1719 class Utils(unittest.TestCase):
1720     def test_cache(self):
1721         c = ResponseCache()
1722         # xdata = base62.b2a(os.urandom(100))[:100]
1723         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1724         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1725         c.add("v1", 1, 0, xdata)
1726         c.add("v1", 1, 2000, ydata)
1727         self.failUnlessEqual(c.read("v2", 1, 10, 11), None)
1728         self.failUnlessEqual(c.read("v1", 2, 10, 11), None)
1729         self.failUnlessEqual(c.read("v1", 1, 0, 10), xdata[:10])
1730         self.failUnlessEqual(c.read("v1", 1, 90, 10), xdata[90:])
1731         self.failUnlessEqual(c.read("v1", 1, 300, 10), None)
1732         self.failUnlessEqual(c.read("v1", 1, 2050, 5), ydata[50:55])
1733         self.failUnlessEqual(c.read("v1", 1, 0, 101), None)
1734         self.failUnlessEqual(c.read("v1", 1, 99, 1), xdata[99:100])
1735         self.failUnlessEqual(c.read("v1", 1, 100, 1), None)
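        # the reads below all begin in the uncached gap between the two
        # regions (offsets 100..1999), so they miss and return None, even
        # when the requested span would overlap cached data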
1736         self.failUnlessEqual(c.read("v1", 1, 1990, 9), None)
1737         self.failUnlessEqual(c.read("v1", 1, 1990, 10), None)
1738         self.failUnlessEqual(c.read("v1", 1, 1990, 11), None)
1739         self.failUnlessEqual(c.read("v1", 1, 1990, 15), None)
1740         self.failUnlessEqual(c.read("v1", 1, 1990, 19), None)
1741         self.failUnlessEqual(c.read("v1", 1, 1990, 20), None)
1742         self.failUnlessEqual(c.read("v1", 1, 1990, 21), None)
1743         self.failUnlessEqual(c.read("v1", 1, 1990, 25), None)
1744         self.failUnlessEqual(c.read("v1", 1, 1999, 25), None)
1745
1746         # test joining fragments
1747         c = ResponseCache()
1748         c.add("v1", 1, 0, xdata[:10])
1749         c.add("v1", 1, 10, xdata[10:20])
1750         self.failUnlessEqual(c.read("v1", 1, 0, 20), xdata[:20])
1751
1752 class Exceptions(unittest.TestCase):
1753     def test_repr(self):
1754         nmde = NeedMoreDataError(100, 50, 100)
1755         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1756         ucwe = UncoordinatedWriteError()
1757         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1758
1759 class SameKeyGenerator:
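    # a key generator that always returns the same pre-made RSA keypair; this
    # lets a test (e.g. test_bad_server) pin a mutable file to a known
    # storage index before the file is created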
1760     def __init__(self, pubkey, privkey):
1761         self.pubkey = pubkey
1762         self.privkey = privkey
1763     def generate(self, keysize=None):
1764         return defer.succeed( (self.pubkey, self.privkey) )
1765
1766 class FirstServerGetsKilled:
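    # a post-call notifier: after the first server answers any query, mark
    # that server's connection as broken so its later calls fail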
1767     done = False
1768     def notify(self, retval, wrapper, methname):
1769         if not self.done:
1770             wrapper.broken = True
1771             self.done = True
1772         return retval
1773
1774 class FirstServerGetsDeleted:
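    # a post-call notifier: the first server to answer gets "silenced", so
    # its later slot_testv_and_readv_and_writev calls claim success with no
    # shares, as if its shares had been deleted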
1775     def __init__(self):
1776         self.done = False
1777         self.silenced = None
1778     def notify(self, retval, wrapper, methname):
1779         if not self.done:
1780             # this query will work, but later queries should think the share
1781             # has been deleted
1782             self.done = True
1783             self.silenced = wrapper
1784             return retval
1785         if wrapper == self.silenced:
1786             assert methname == "slot_testv_and_readv_and_writev"
1787             return (True, {})
1788         return retval
1789
1790 class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
1791     def test_publish_surprise(self):
1792         self.basedir = "mutable/Problems/test_publish_surprise"
1793         self.set_up_grid()
1794         nm = self.g.clients[0].nodemaker
1795         d = nm.create_mutable_file("contents 1")
1796         def _created(n):
1797             d = defer.succeed(None)
1798             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1799             def _got_smap1(smap):
1800                 # stash the old state of the file
1801                 self.old_map = smap
1802             d.addCallback(_got_smap1)
1803             # then modify the file, leaving the old map untouched
1804             d.addCallback(lambda res: log.msg("starting winning write"))
1805             d.addCallback(lambda res: n.overwrite("contents 2"))
1806             # now attempt to modify the file with the old servermap. This
1807             # will look just like an uncoordinated write, in which every
1808             # single share got updated between our mapupdate and our publish
1809             d.addCallback(lambda res: log.msg("starting doomed write"))
1810             d.addCallback(lambda res:
1811                           self.shouldFail(UncoordinatedWriteError,
1812                                           "test_publish_surprise", None,
1813                                           n.upload,
1814                                           "contents 2a", self.old_map))
1815             return d
1816         d.addCallback(_created)
1817         return d
1818
1819     def test_retrieve_surprise(self):
1820         self.basedir = "mutable/Problems/test_retrieve_surprise"
1821         self.set_up_grid()
1822         nm = self.g.clients[0].nodemaker
1823         d = nm.create_mutable_file("contents 1")
1824         def _created(n):
1825             d = defer.succeed(None)
1826             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1827             def _got_smap1(smap):
1828                 # stash the old state of the file
1829                 self.old_map = smap
1830             d.addCallback(_got_smap1)
1831             # then modify the file, leaving the old map untouched
1832             d.addCallback(lambda res: log.msg("starting winning write"))
1833             d.addCallback(lambda res: n.overwrite("contents 2"))
1834             # now attempt to retrieve the old version with the old servermap.
1835             # This will look like someone has changed the file since we
1836             # updated the servermap.
1837             d.addCallback(lambda res: n._cache._clear())
1838             d.addCallback(lambda res: log.msg("starting doomed read"))
1839             d.addCallback(lambda res:
1840                           self.shouldFail(NotEnoughSharesError,
1841                                           "test_retrieve_surprise",
1842                                           "ran out of peers: have 0 shares (k=3)",
1843                                           n.download_version,
1844                                           self.old_map,
1845                                           self.old_map.best_recoverable_version(),
1846                                           ))
1847             return d
1848         d.addCallback(_created)
1849         return d
1850
1851     def test_unexpected_shares(self):
1852         # upload the file, take a servermap, shut down one of the servers,
1853         # upload it again (causing shares to appear on a new server), then
1854         # upload using the old servermap. The last upload should fail with an
1855         # UncoordinatedWriteError, because of the shares that didn't appear
1856         # in the servermap.
1857         self.basedir = "mutable/Problems/test_unexpected_shares"
1858         self.set_up_grid()
1859         nm = self.g.clients[0].nodemaker
1860         d = nm.create_mutable_file("contents 1")
1861         def _created(n):
1862             d = defer.succeed(None)
1863             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1864             def _got_smap1(smap):
1865                 # stash the old state of the file
1866                 self.old_map = smap
1867                 # now shut down one of the servers
1868                 peer0 = list(smap.make_sharemap()[0])[0]
1869                 self.g.remove_server(peer0)
1870                 # then modify the file, leaving the old map untouched
1871                 log.msg("starting winning write")
1872                 return n.overwrite("contents 2")
1873             d.addCallback(_got_smap1)
1874             # now attempt to modify the file with the old servermap. This
1875             # will look just like an uncoordinated write, in which every
1876             # single share got updated between our mapupdate and our publish
1877             d.addCallback(lambda res: log.msg("starting doomed write"))
1878             d.addCallback(lambda res:
1879                           self.shouldFail(UncoordinatedWriteError,
1880                                           "test_unexpected_shares", None,
1881                                           n.upload,
1882                                           "contents 2a", self.old_map))
1883             return d
1884         d.addCallback(_created)
1885         return d
1886
1887     def test_bad_server(self):
1888         # Break one server, then create the file: the initial publish should
1889         # complete with an alternate server. Breaking a second server should
1890         # not prevent an update from succeeding either.
1891         self.basedir = "mutable/Problems/test_bad_server"
1892         self.set_up_grid()
1893         nm = self.g.clients[0].nodemaker
1894
1895         # to make sure that one of the initial peers is broken, we have to
1896         # get creative. We create an RSA key and compute its storage-index.
1897         # Then we make a KeyGenerator that always returns that one key, and
1898         # use it to create the mutable file. This will get easier when we can
1899         # use #467 static-server-selection to disable permutation and force
1900         # the choice of server for share[0].
1901
1902         d = nm.key_generator.generate(522)
1903         def _got_key( (pubkey, privkey) ):
1904             nm.key_generator = SameKeyGenerator(pubkey, privkey)
1905             pubkey_s = pubkey.serialize()
1906             privkey_s = privkey.serialize()
1907             u = uri.WriteableSSKFileURI(ssk_writekey_hash(privkey_s),
1908                                         ssk_pubkey_fingerprint_hash(pubkey_s))
1909             self._storage_index = u.get_storage_index()
1910         d.addCallback(_got_key)
1911         def _break_peer0(res):
1912             si = self._storage_index
1913             servers = nm.storage_broker.get_servers_for_psi(si)
1914             self.g.break_server(servers[0].get_serverid())
1915             self.server1 = servers[1]
1916         d.addCallback(_break_peer0)
1917         # now "create" the file, using the pre-established key, and let the
1918         # initial publish finally happen
1919         d.addCallback(lambda res: nm.create_mutable_file("contents 1"))
1920         # that ought to work
1921         def _got_node(n):
1922             d = n.download_best_version()
1923             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1924             # now break the second peer
1925             def _break_peer1(res):
1926                 self.g.break_server(self.server1.get_serverid())
1927             d.addCallback(_break_peer1)
1928             d.addCallback(lambda res: n.overwrite("contents 2"))
1929             # that ought to work too
1930             d.addCallback(lambda res: n.download_best_version())
1931             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1932             def _explain_error(f):
1933                 print f
1934                 if f.check(NotEnoughServersError):
1935                     print "first_error:", f.value.first_error
1936                 return f
1937             d.addErrback(_explain_error)
1938             return d
1939         d.addCallback(_got_node)
1940         return d
1941
1942     def test_bad_server_overlap(self):
1943         # like test_bad_server, but with no extra unused servers to fall back
1944         # upon. This means that we must re-use a server which we've already
1945         # used. If we don't remember the fact that we sent them one share
1946         # already, we'll mistakenly think we're experiencing an
1947         # UncoordinatedWriteError.
1948
1949         # Break one server, then create the file: the initial publish should
1950         # complete with an alternate server. Breaking a second server should
1951         # not prevent an update from succeeding either.
1952         self.basedir = "mutable/Problems/test_bad_server_overlap"
1953         self.set_up_grid()
1954         nm = self.g.clients[0].nodemaker
1955         sb = nm.storage_broker
1956
1957         peerids = [s.get_serverid() for s in sb.get_connected_servers()]
1958         self.g.break_server(peerids[0])
1959
1960         d = nm.create_mutable_file("contents 1")
1961         def _created(n):
1962             d = n.download_best_version()
1963             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1964             # now break one of the remaining servers
1965             def _break_second_server(res):
1966                 self.g.break_server(peerids[1])
1967             d.addCallback(_break_second_server)
1968             d.addCallback(lambda res: n.overwrite("contents 2"))
1969             # that ought to work too
1970             d.addCallback(lambda res: n.download_best_version())
1971             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1972             return d
1973         d.addCallback(_created)
1974         return d
1975
1976     def test_publish_all_servers_bad(self):
1977         # Break all servers: the publish should fail
1978         self.basedir = "mutable/Problems/test_publish_all_servers_bad"
1979         self.set_up_grid()
1980         nm = self.g.clients[0].nodemaker
1981         for s in nm.storage_broker.get_connected_servers():
1982             s.get_rref().broken = True
1983
1984         d = self.shouldFail(NotEnoughServersError,
1985                             "test_publish_all_servers_bad",
1986                             "Ran out of non-bad servers",
1987                             nm.create_mutable_file, "contents")
1988         return d
1989
1990     def test_publish_no_servers(self):
1991         # no servers at all: the publish should fail
1992         self.basedir = "mutable/Problems/test_publish_no_servers"
1993         self.set_up_grid(num_servers=0)
1994         nm = self.g.clients[0].nodemaker
1995
1996         d = self.shouldFail(NotEnoughServersError,
1997                             "test_publish_no_servers",
1998                             "Ran out of non-bad servers",
1999                             nm.create_mutable_file, "contents")
2000         return d
2001     test_publish_no_servers.timeout = 30
2002
2003
2004     def test_privkey_query_error(self):
2005         # when a servermap is updated with MODE_WRITE, it tries to get the
2006         # privkey. Something might go wrong during this query attempt.
2007         # Exercise the code in _privkey_query_failed which tries to handle
2008         # such an error.
2009         self.basedir = "mutable/Problems/test_privkey_query_error"
2010         self.set_up_grid(num_servers=20)
2011         nm = self.g.clients[0].nodemaker
2012         nm._node_cache = DevNullDictionary() # disable the nodecache
2013
2014         # we need some contents that are large enough to push the privkey out
2015         # of the early part of the file
2016         LARGE = "These are Larger contents" * 2000 # about 50KB
2017         d = nm.create_mutable_file(LARGE)
2018         def _created(n):
2019             self.uri = n.get_uri()
2020             self.n2 = nm.create_from_cap(self.uri)
2021
2022             # When a mapupdate is performed on a node that doesn't yet know
2023             # the privkey, a short read is sent to a batch of servers, to get
2024             # the verinfo and (hopefully, if the file is short enough) the
2025             # encprivkey. Our file is too large to let this first read
2026             # contain the encprivkey. Each non-encprivkey-bearing response
2027             # that arrives (until the node gets the encprivkey) will trigger
2028             # a second read to specifically read the encprivkey.
2029             #
2030             # So, to exercise this case:
2031             #  1. notice which server gets a read() call first
2032             #  2. tell that server to start throwing errors
2033             killer = FirstServerGetsKilled()
2034             for s in nm.storage_broker.get_connected_servers():
2035                 s.get_rref().post_call_notifier = killer.notify
2036         d.addCallback(_created)
2037
2038         # now we update a servermap from a new node (which doesn't have the
2039         # privkey yet, forcing it to use a separate privkey query). Note that
2040         # the map-update will succeed, since we'll just get a copy from one
2041         # of the other shares.
2042         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2043
2044         return d
2045
2046     def test_privkey_query_missing(self):
2047         # like test_privkey_query_error, but the shares are deleted by the
2048         # second query, instead of raising an exception.
2049         self.basedir = "mutable/Problems/test_privkey_query_missing"
2050         self.set_up_grid(num_servers=20)
2051         nm = self.g.clients[0].nodemaker
2052         LARGE = "These are Larger contents" * 2000 # about 50KB
2053         nm._node_cache = DevNullDictionary() # disable the nodecache
2054
2055         d = nm.create_mutable_file(LARGE)
2056         def _created(n):
2057             self.uri = n.get_uri()
2058             self.n2 = nm.create_from_cap(self.uri)
2059             deleter = FirstServerGetsDeleted()
2060             for s in nm.storage_broker.get_connected_servers():
2061                 s.get_rref().post_call_notifier = deleter.notify
2062         d.addCallback(_created)
2063         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2064         return d