import struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from allmydata import uri, client
from allmydata.nodemaker import NodeMaker
from allmydata.util import base32
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
     ssk_pubkey_fingerprint_hash
from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
     NotEnoughSharesError
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from allmydata.test.no_network import GridTestMixin
from foolscap.api import eventually, fireEventually
from foolscap.logging import log
from allmydata.storage_client import StorageFarmBroker

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

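# A minimal sketch (added for illustration; not used by the tests below) of
# how FakeStorage._sequence is driven: while it is set, read() defers every
# query, and _fire_readers() answers them in this order about one second
# after the first query arrives.
def _example_force_response_order(s, peerids):
    # 's' is a FakeStorage; 'peerids' is assumed to be the list of server
    # ids registered with make_storagebroker() below
    s._sequence = list(peerids)
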
class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: the result (or error) is deliberately ignored
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

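    # For example (illustrative; assumes share 0 exists on this peer):
    #   slot_readv(si, [0], [(0, 10), (20, 5)]) eventually fires with
    #   {0: [<bytes 0..9 of share 0>, <bytes 20..24 of share 0>]}
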
    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

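# A worked example, added for clarity: flip_bit("\x00\xff", 1) returns
# "\x00\xfe", since only the low bit of the byte at byte_offset is inverted.
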
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

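# The invocation patterns used by the tests below: 'offset' may be a raw
# byte offset, a field name from the share's offset table, or a
# (name, byte_delta) tuple. For example:
#
#     corrupt(None, self._storage, 1)                        # mangle seqnum
#     corrupt(None, self._storage, "share_data", range(5))   # first 5 shares
#     corrupt(None, self._storage, ("share_hash_chain", 4))  # 4 bytes in
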
def make_storagebroker(s=None, num_peers=10):
    if not s:
        s = FakeStorage()
    peerids = [tagged_hash("peerid", "%d" % i)[:20]
               for i in range(num_peers)]
    storage_broker = StorageFarmBroker(None, True)
    for peerid in peerids:
        fss = FakeStorageServer(peerid, s)
        storage_broker.test_add_server(peerid, fss)
    return storage_broker

def make_nodemaker(s=None, num_peers=10):
    storage_broker = make_storagebroker(s, num_peers)
    sh = client.SecretHolder("lease secret", "convergence secret")
    keygen = client.KeyGenerator()
    keygen.set_default_keysize(522)
    nodemaker = NodeMaker(storage_broker, sh, None,
                          None, None, None,
                          {"k": 3, "n": 10}, keygen)
    return nodemaker

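# The fixture pattern the tests below rely on (sketch): a NodeMaker wired to
# the in-RAM FakeStorage hands back MutableFileNodes whose shares land in
# FakeStorage._peers, where the tests can inspect and corrupt them.
#
#     s = FakeStorage()
#     nm = make_nodemaker(s)
#     d = nm.create_mutable_file("contents")
#     d.addCallback(lambda n: n.download_best_version())
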
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # the segment-size limit used to be enforced in Publish, but we removed
    # it. Some of these tests check that the new code correctly allows files
    # larger than that old limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self._storage = s = FakeStorage()
        self.nodemaker = make_nodemaker(s)

    def test_create(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, MutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.nodemaker.storage_broker
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(None, None, {"k": 3, "n": 10}, None)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.nodemaker.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents_function(self):
        data = "initial contents"
        def _make_contents(n):
            self.failUnless(isinstance(n, MutableFileNode))
            key = n.get_writekey()
            self.failUnless(isinstance(key, str), key)
            self.failUnlessEqual(len(key), 16) # AES key size
            return data
        d = self.nodemaker.create_mutable_file(_make_contents)
        def _created(n):
            return n.download_best_version()
        d.addCallback(_created)
        d.addCallback(lambda data2: self.failUnlessEqual(data2, data))
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.nodemaker.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

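    # (the verinfo returned by best_recoverable_version() is the tuple
    # unpacked in Roundtrip.abbrev_verinfo() below; its first element is the
    # sequence number)
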
    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            # the "contents too big" failure case went away when the size
            # limit was removed; _toobig_modifier (exercised at the end of
            # this test) is now expected to succeed
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.nodemaker.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.nodemaker.key_generator = client.KeyGenerator()
        d = self.nodemaker.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            p = Publish(fn, nm.storage_broker, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                # 21 bytes of contents, k=3: each encoded share is 7 bytes
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        nm = make_nodemaker()
        CONTENTS = "some initial contents"
        d = nm.create_mutable_file(CONTENTS)
        def _created(fn):
            self._fn = fn
            p = Publish(fn, nm.storage_broker, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, self._fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer, on
    # 10 of them; when we publish to 3 peers, we should get either 3 or 4
    # shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        self._storage_broker = self._nodemaker.storage_broker
        d = self._nodemaker.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
        d.addCallback(_created)
        return d

    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._storage = FakeStorage()
        self._nodemaker = make_nodemaker(self._storage)
        d = self._nodemaker.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
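    # Usage sketch (mirroring publish_multiple above): rolling every share
    # back to the version-2 snapshot looks like
    #     self._set_versions(dict([(i, 2) for i in range(10)]))
    # while a partial map like {0: 2} would roll back only share 0.
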
class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None):
        if fn is None:
            fn = self._fn
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(fn, sb, Monitor(),
                               ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
                               oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

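    # (reading of the assertion above: shares_available() maps each verinfo
    # to a tuple whose first element counts the distinct shares found and
    # whose trailing elements are the encoding's k and N -- hence
    # (num_shares, 3, 10) for this 3-of-10 file)
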
    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._nodemaker.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._nodemaker.create_from_cap(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
        if oldmap is None:
            oldmap = ServerMap()
        if sb is None:
            sb = self._storage_broker
        smu = ServermapUpdater(self._fn, sb, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        sb2 = make_storagebroker(num_peers=0)
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap(sb=sb2)
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        sb2 = make_storagebroker(num_peers=0)
        self._fn._storage_broker = sb2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._storage_broker = self._storage_broker
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnlessIn(substring, "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an UnknownVersionError in
        # unpack_share().
        d = self._test_corrupt_all(0, "UnknownVersionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class CheckerMixin:
    def check_good(self, r, where):
        self.failUnless(r.is_healthy(), where)
        return r

    def check_bad(self, r, where):
        self.failIf(r.is_healthy(), where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        for (peerid, storage_index, shnum, f) in r.problems:
            if f.check(expected_exception):
                self.failUnless(substring in str(f),
                                "%s: substring '%s' not in '%s'" %
                                (where, substring, str(f)))
                return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))


1184 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1185     def setUp(self):
1186         return self.publish_one()
1187
1188
1189     def test_check_good(self):
1190         d = self._fn.check(Monitor())
1191         d.addCallback(self.check_good, "test_check_good")
1192         return d
1193
1194     def test_check_no_shares(self):
1195         for shares in self._storage._peers.values():
1196             shares.clear()
1197         d = self._fn.check(Monitor())
1198         d.addCallback(self.check_bad, "test_check_no_shares")
1199         return d
1200
1201     def test_check_not_enough_shares(self):
1202         for shares in self._storage._peers.values():
1203             for shnum in shares.keys():
1204                 if shnum > 0:
1205                     del shares[shnum]
1206         d = self._fn.check(Monitor())
1207         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1208         return d
1209
1210     def test_check_all_bad_sig(self):
1211         corrupt(None, self._storage, 1) # bad sig
1212         d = self._fn.check(Monitor())
1213         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1214         return d
1215
1216     def test_check_all_bad_blocks(self):
1217         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1218         # the Checker won't notice this, since it doesn't look at actual data
1219         d = self._fn.check(Monitor())
1220         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1221         return d
1222
1223     def test_verify_good(self):
1224         d = self._fn.check(Monitor(), verify=True)
1225         d.addCallback(self.check_good, "test_verify_good")
1226         return d
1227
1228     def test_verify_all_bad_sig(self):
1229         corrupt(None, self._storage, 1) # bad sig
1230         d = self._fn.check(Monitor(), verify=True)
1231         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1232         return d
1233
1234     def test_verify_one_bad_sig(self):
1235         corrupt(None, self._storage, 1, [9]) # bad sig on just one share
1236         d = self._fn.check(Monitor(), verify=True)
1237         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1238         return d
1239
1240     def test_verify_one_bad_block(self):
1241         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1242         # the Verifier *will* notice this, since it examines every byte
1243         d = self._fn.check(Monitor(), verify=True)
1244         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1245         d.addCallback(self.check_expected_failure,
1246                       CorruptShareError, "block hash tree failure",
1247                       "test_verify_one_bad_block")
1248         return d
1249
1250     def test_verify_one_bad_sharehash(self):
1251         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1252         d = self._fn.check(Monitor(), verify=True)
1253         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1254         d.addCallback(self.check_expected_failure,
1255                       CorruptShareError, "corrupt hashes",
1256                       "test_verify_one_bad_sharehash")
1257         return d
1258
1259     def test_verify_one_bad_encprivkey(self):
1260         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1261         d = self._fn.check(Monitor(), verify=True)
1262         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1263         d.addCallback(self.check_expected_failure,
1264                       CorruptShareError, "invalid privkey",
1265                       "test_verify_one_bad_encprivkey")
1266         return d
1267
1268     def test_verify_one_bad_encprivkey_uncheckable(self):
1269         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1270         readonly_fn = self._fn.get_readonly()
1271         # a read-only node has no way to validate the privkey
1272         d = readonly_fn.check(Monitor(), verify=True)
1273         d.addCallback(self.check_good,
1274                       "test_verify_one_bad_encprivkey_uncheckable")
1275         return d
1276
1277 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1278
1279     def get_shares(self, s):
1280         all_shares = {} # maps (peerid, shnum) to share data
1281         for peerid in s._peers:
1282             shares = s._peers[peerid]
1283             for shnum in shares:
1284                 data = shares[shnum]
1285                 all_shares[ (peerid, shnum) ] = data
1286         return all_shares
1287
1288     def copy_shares(self, ignored=None):
1289         self.old_shares.append(self.get_shares(self._storage))
1290
1291     def test_repair_nop(self):
1292         self.old_shares = []
1293         d = self.publish_one()
1294         d.addCallback(self.copy_shares)
1295         d.addCallback(lambda res: self._fn.check(Monitor()))
1296         d.addCallback(lambda check_results: self._fn.repair(check_results))
1297         def _check_results(rres):
1298             self.failUnless(IRepairResults.providedBy(rres))
1299             # TODO: examine results
1300
1301             self.copy_shares()
1302
1303             initial_shares = self.old_shares[0]
1304             new_shares = self.old_shares[1]
1305             # TODO: this really shouldn't change anything. When we implement
1306             # a "minimal-bandwidth" repairer, change this test to assert:
1307             #self.failUnlessEqual(new_shares, initial_shares)
1308
1309             # all shares should be in the same place as before
1310             self.failUnlessEqual(set(initial_shares.keys()),
1311                                  set(new_shares.keys()))
1312             # but they should all be at a newer seqnum. The IV will be
1313             # different, so the roothash will be too.
1314             for key in initial_shares:
1315                 (version0,
1316                  seqnum0,
1317                  root_hash0,
1318                  IV0,
1319                  k0, N0, segsize0, datalen0,
1320                  o0) = unpack_header(initial_shares[key])
1321                 (version1,
1322                  seqnum1,
1323                  root_hash1,
1324                  IV1,
1325                  k1, N1, segsize1, datalen1,
1326                  o1) = unpack_header(new_shares[key])
1327                 self.failUnlessEqual(version0, version1)
1328                 self.failUnlessEqual(seqnum0+1, seqnum1)
1329                 self.failUnlessEqual(k0, k1)
1330                 self.failUnlessEqual(N0, N1)
1331                 self.failUnlessEqual(segsize0, segsize1)
1332                 self.failUnlessEqual(datalen0, datalen1)
1333         d.addCallback(_check_results)
1334         return d
1335
1336     def failIfSharesChanged(self, ignored=None):
1337         old_shares = self.old_shares[-2]
1338         current_shares = self.old_shares[-1]
1339         self.failUnlessEqual(old_shares, current_shares)
1340
1341     def test_merge(self):
1342         self.old_shares = []
1343         d = self.publish_multiple()
1344         # repair will refuse to merge multiple highest seqnums unless you
1345         # pass force=True
1346         d.addCallback(lambda res:
1347                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1348                                           1:4,3:4,5:4,7:4,9:4}))
1349         d.addCallback(self.copy_shares)
1350         d.addCallback(lambda res: self._fn.check(Monitor()))
1351         def _try_repair(check_results):
1352             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1353             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1354                                  self._fn.repair, check_results)
1355             d2.addCallback(self.copy_shares)
1356             d2.addCallback(self.failIfSharesChanged)
1357             d2.addCallback(lambda res: check_results)
1358             return d2
1359         d.addCallback(_try_repair)
1360         d.addCallback(lambda check_results:
1361                       self._fn.repair(check_results, force=True))
1362         # this should give us 10 shares of the highest roothash
1363         def _check_repair_results(rres):
1364             pass # TODO
1365         d.addCallback(_check_repair_results)
1366         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1367         def _check_smap(smap):
1368             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1369             self.failIf(smap.unrecoverable_versions())
1370             # now, which should have won?
1371             roothash_s4a = self.get_roothash_for(3)
1372             roothash_s4b = self.get_roothash_for(4)
1373             if roothash_s4b > roothash_s4a:
1374                 expected_contents = self.CONTENTS[4]
1375             else:
1376                 expected_contents = self.CONTENTS[3]
1377             new_versionid = smap.best_recoverable_version()
1378             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1379             d2 = self._fn.download_version(smap, new_versionid)
1380             d2.addCallback(self.failUnlessEqual, expected_contents)
1381             return d2
1382         d.addCallback(_check_smap)
1383         return d
1384
1385     def test_non_merge(self):
1386         self.old_shares = []
1387         d = self.publish_multiple()
1388         # repair should not refuse a repair that doesn't need to merge. In
1389         # this case, we combine v2 with v3. The repair should ignore v2 and
1390         # copy v3 into a new v5.
1391         d.addCallback(lambda res:
1392                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1393                                           1:3,3:3,5:3,7:3,9:3}))
1394         d.addCallback(lambda res: self._fn.check(Monitor()))
1395         d.addCallback(lambda check_results: self._fn.repair(check_results))
1396         # this should give us 10 shares of v3
1397         def _check_repair_results(rres):
1398             pass # TODO
1399         d.addCallback(_check_repair_results)
1400         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1401         def _check_smap(smap):
1402             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1403             self.failIf(smap.unrecoverable_versions())
1404             # v3 should have won, since the repair ignored v2 entirely
1405             # (unlike test_merge, there is no roothash tie to break)
1406             expected_contents = self.CONTENTS[3]
1407             new_versionid = smap.best_recoverable_version()
1408             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1409             d2 = self._fn.download_version(smap, new_versionid)
1410             d2.addCallback(self.failUnlessEqual, expected_contents)
1411             return d2
1412         d.addCallback(_check_smap)
1413         return d
1414
1415     def get_roothash_for(self, index):
1416         # return the roothash for the first share we see in the saved set
1417         shares = self._copied_shares[index]
1418         for peerid in shares:
1419             for shnum in shares[peerid]:
1420                 share = shares[peerid][shnum]
1421                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1422                           unpack_header(share)
1423                 return root_hash
1424
1425     def test_check_and_repair_readcap(self):
1426         # we can't currently repair from a mutable readcap: #625
1427         self.old_shares = []
1428         d = self.publish_one()
1429         d.addCallback(self.copy_shares)
1430         def _get_readcap(res):
1431             self._fn3 = self._fn.get_readonly()
1432             # also delete some shares
1433             for peerid,shares in self._storage._peers.items():
1434                 shares.pop(0, None)
1435         d.addCallback(_get_readcap)
1436         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1437         def _check_results(crr):
1438             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1439             # we should detect that the file is unhealthy, but skip the
1440             # mutable-readcap repair until #625 is fixed
1441             self.failIf(crr.get_pre_repair_results().is_healthy())
1442             self.failIf(crr.get_repair_attempted())
1443             self.failIf(crr.get_post_repair_results().is_healthy())
1444         d.addCallback(_check_results)
1445         return d
1446
1447 class DevNullDictionary(dict):
1448     def __setitem__(self, key, value):
1449         return
1450
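# A minimal sketch of why DevNullDictionary is useful here (hypothetical
# variable names): installing it as a nodemaker's node cache makes every
# create_from_cap() call build a fresh node, since cache writes are silently
# dropped. The tests below rely on this to get multiple independent nodes
# pointing at the same file.
#
#   nm._node_cache = DevNullDictionary()  # cache writes are discarded
#   n1 = nm.create_from_cap(cap)
#   n2 = nm.create_from_cap(cap)
#   assert n1 is not n2                   # no cache hit, distinct nodes
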
1451 class MultipleEncodings(unittest.TestCase):
1452     def setUp(self):
1453         self.CONTENTS = "New contents go here"
1454         self._storage = FakeStorage()
1455         self._nodemaker = make_nodemaker(self._storage, num_peers=20)
1456         self._storage_broker = self._nodemaker.storage_broker
1457         d = self._nodemaker.create_mutable_file(self.CONTENTS)
1458         def _created(node):
1459             self._fn = node
1460         d.addCallback(_created)
1461         return d
1462
1463     def _encode(self, k, n, data):
1464         # encode 'data' into a peerid->shares dict.
1465
1466         fn = self._fn
1467         # disable the nodecache, since for these tests we explicitly need
1468         # multiple nodes pointing at the same file
1469         self._nodemaker._node_cache = DevNullDictionary()
1470         fn2 = self._nodemaker.create_from_cap(fn.get_uri())
1471         # then we copy over other fields that are normally fetched from the
1472         # existing shares
1473         fn2._pubkey = fn._pubkey
1474         fn2._privkey = fn._privkey
1475         fn2._encprivkey = fn._encprivkey
1476         # and set the encoding parameters to something completely different
1477         fn2._required_shares = k
1478         fn2._total_shares = n
1479
1480         s = self._storage
1481         s._peers = {} # clear existing storage
1482         p2 = Publish(fn2, self._storage_broker, None)
1483         d = p2.publish(data)
1484         def _published(res):
1485             shares = s._peers
1486             s._peers = {}
1487             return shares
1488         d.addCallback(_published)
1489         return d
1490
1491     def make_servermap(self, mode=MODE_READ, oldmap=None):
1492         if oldmap is None:
1493             oldmap = ServerMap()
1494         smu = ServermapUpdater(self._fn, self._storage_broker, Monitor(),
1495                                oldmap, mode)
1496         d = smu.update()
1497         return d
1498
1499     def test_multiple_encodings(self):
1500         # we encode the same file in two different ways (3-of-10 and 4-of-9),
1501         # then mix up the shares, to make sure that download survives seeing
1502         # a variety of encodings. This is actually kind of tricky to set up.
1503
1504         contents1 = "Contents for encoding 1 (3-of-10) go here"
1505         contents2 = "Contents for encoding 2 (4-of-9) go here"
1506         contents3 = "Contents for encoding 3 (4-of-7) go here"
1507
1508         # we make a retrieval object that doesn't know what encoding
1509         # parameters to use
1510         fn3 = self._nodemaker.create_from_cap(self._fn.get_uri())
1511
1512         # now we publish the file in each encoding, and grab the shares
1513         d = self._encode(3, 10, contents1)
1514         def _encoded_1(shares):
1515             self._shares1 = shares
1516         d.addCallback(_encoded_1)
1517         d.addCallback(lambda res: self._encode(4, 9, contents2))
1518         def _encoded_2(shares):
1519             self._shares2 = shares
1520         d.addCallback(_encoded_2)
1521         d.addCallback(lambda res: self._encode(4, 7, contents3))
1522         def _encoded_3(shares):
1523             self._shares3 = shares
1524         d.addCallback(_encoded_3)
1525
1526         def _merge(res):
1527             log.msg("merging sharelists")
1528             # we merge the shares from the three sets, leaving each shnum
1529             # in its original location, but using a share from set1, set2,
1530             # or set3 according to the following sequence:
1531             #
1532             #  4-of-9  a  s2
1533             #  4-of-9  b  s2
1534             #  4-of-7  c   s3
1535             #  4-of-9  d  s2
1536             #  3-of-10 e s1
1537             #  3-of-10 f s1
1538             #  3-of-10 g s1
1539             #  4-of-9  h  s2
1540             #
1541             # so that no version can be recovered until fetch [g], at which
1542             # point version-s1 (the 3-of-10 form) becomes recoverable. If
1543             # the implementation latches on to the first version it sees,
1544             # then s2 will not be recoverable until fetch [h].
1545
1546             # Later, when we implement code that handles multiple versions,
1547             # we can use this framework to assert that all recoverable
1548             # versions are retrieved, and test that 'epsilon' does its job.
1549
1550             places = [2, 2, 3, 2, 1, 1, 1, 2]
1551
1552             sharemap = {}
1553             sb = self._storage_broker
1554
1555             for peerid in sorted(sb.get_all_serverids()):
1556                 peerid_s = shortnodeid_b2a(peerid)
1557                 self._storage._peers[peerid] = peers = {} # reset once per peer
1558                 for shnum in self._shares1.get(peerid, {}):
1559                     if shnum < len(places):
1560                         which = places[shnum]
1561                     else:
1562                         which = "x"
1563                     in_1 = shnum in self._shares1[peerid]
1564                     in_2 = shnum in self._shares2.get(peerid, {})
1565                     in_3 = shnum in self._shares3.get(peerid, {})
1566                     #print peerid_s, shnum, which, in_1, in_2, in_3
1567                     if which == 1:
1568                         if in_1:
1569                             peers[shnum] = self._shares1[peerid][shnum]
1570                             sharemap[shnum] = peerid
1571                     elif which == 2:
1572                         if in_2:
1573                             peers[shnum] = self._shares2[peerid][shnum]
1574                             sharemap[shnum] = peerid
1575                     elif which == 3:
1576                         if in_3:
1577                             peers[shnum] = self._shares3[peerid][shnum]
1578                             sharemap[shnum] = peerid
1579
1580             # we don't bother placing any other shares
1581             # now sort the sequence so that share 0 is returned first
1582             new_sequence = [sharemap[shnum]
1583                             for shnum in sorted(sharemap.keys())]
1584             self._storage._sequence = new_sequence
1585             log.msg("merge done")
1586         d.addCallback(_merge)
1587         d.addCallback(lambda res: fn3.download_best_version())
1588         def _retrieved(new_contents):
1589             # the current specified behavior is "first version recoverable"
1590             self.failUnlessEqual(new_contents, contents1)
1591         d.addCallback(_retrieved)
1592         return d
1593
1594
1595 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1596
1597     def setUp(self):
1598         return self.publish_multiple()
1599
1600     def test_multiple_versions(self):
1601         # if we see a mix of versions in the grid, download_best_version
1602         # should get the latest one
1603         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1604         d = self._fn.download_best_version()
1605         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1606         # and the checker should report problems
1607         d.addCallback(lambda res: self._fn.check(Monitor()))
1608         d.addCallback(self.check_bad, "test_multiple_versions")
1609
1610         # but if everything is at version 2, that's what we should download
1611         d.addCallback(lambda res:
1612                       self._set_versions(dict([(i,2) for i in range(10)])))
1613         d.addCallback(lambda res: self._fn.download_best_version())
1614         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1615         # if exactly one share is at version 3, we should still get v2
1616         d.addCallback(lambda res:
1617                       self._set_versions({0:3}))
1618         d.addCallback(lambda res: self._fn.download_best_version())
1619         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1620         # but the servermap should see the unrecoverable version. This
1621         # depends upon the single newer share being queried early.
1622         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1623         def _check_smap(smap):
1624             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1625             newer = smap.unrecoverable_newer_versions()
1626             self.failUnlessEqual(len(newer), 1)
1627             verinfo, health = newer.items()[0]
1628             self.failUnlessEqual(verinfo[0], 4)
1629             self.failUnlessEqual(health, (1,3))
1630             self.failIf(smap.needs_merge())
1631         d.addCallback(_check_smap)
1632         # if we have a mix of two parallel versions (s4a and s4b), we could
1633         # recover either
1634         d.addCallback(lambda res:
1635                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1636                                           1:4,3:4,5:4,7:4,9:4}))
1637         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1638         def _check_smap_mixed(smap):
1639             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1640             newer = smap.unrecoverable_newer_versions()
1641             self.failUnlessEqual(len(newer), 0)
1642             self.failUnless(smap.needs_merge())
1643         d.addCallback(_check_smap_mixed)
1644         d.addCallback(lambda res: self._fn.download_best_version())
1645         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1646                                                   res == self.CONTENTS[4]))
1647         return d
1648
1649     def test_replace(self):
1650         # if we see a mix of versions in the grid, we should be able to
1651         # replace them all with a newer version
1652
1653         # if exactly one share is at a higher seqnum, we should download
1654         # and modify the v2 contents, publishing the result past the outlier.
1655         # Note that the index given to _set_versions differs from the seqnum.
1656         target = dict([(i,2) for i in range(10)]) # seqnum3
1657         target[0] = 3 # seqnum4
1658         self._set_versions(target)
1659
1660         def _modify(oldversion, servermap, first_time):
1661             return oldversion + " modified"
1662         d = self._fn.modify(_modify)
1663         d.addCallback(lambda res: self._fn.download_best_version())
1664         expected = self.CONTENTS[2] + " modified"
1665         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1666         # and the servermap should indicate that the outlier was replaced too
1667         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1668         def _check_smap(smap):
1669             self.failUnlessEqual(smap.highest_seqnum(), 5)
1670             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1671             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1672         d.addCallback(_check_smap)
1673         return d
1674
1675
1676 class Utils(unittest.TestCase):
1677     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1678         # we compare this against sets of integers
1679         x = set(range(x_start, x_start+x_length))
1680         y = set(range(y_start, y_start+y_length))
1681         should_be_inside = x.issubset(y)
1682         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1683                                                          y_start, y_length),
1684                              str((x_start, x_length, y_start, y_length)))
1685
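    # For reference, a sketch of the containment predicate being exercised,
    # assuming ResponseCache._inside is equivalent to the set-based check
    # above (ignoring degenerate empty ranges; the real code may differ):
    #
    #   def inside(x_start, x_length, y_start, y_length):
    #       # [x_start, x_start+x_length) must lie in [y_start, y_start+y_length)
    #       return (x_start >= y_start and
    #               x_start + x_length <= y_start + y_length)
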
1686     def test_cache_inside(self):
1687         c = ResponseCache()
1688         x_start = 10
1689         x_length = 5
1690         for y_start in range(8, 17):
1691             for y_length in range(8):
1692                 self._do_inside(c, x_start, x_length, y_start, y_length)
1693
1694     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1695         # we compare this against sets of integers
1696         x = set(range(x_start, x_start+x_length))
1697         y = set(range(y_start, y_start+y_length))
1698         overlap = bool(x.intersection(y))
1699         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1700                                                       y_start, y_length),
1701                              str((x_start, x_length, y_start, y_length)))
1702
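    # And the matching overlap predicate, again an assumed-equivalent sketch
    # of ResponseCache._does_overlap rather than the actual code; note that
    # an empty range overlaps nothing, which the set-based check above
    # encodes naturally:
    #
    #   def does_overlap(x_start, x_length, y_start, y_length):
    #       if x_length == 0 or y_length == 0:
    #           return False  # an empty range intersects nothing
    #       return (x_start < y_start + y_length and
    #               y_start < x_start + x_length)
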
1703     def test_cache_overlap(self):
1704         c = ResponseCache()
1705         x_start = 10
1706         x_length = 5
1707         for y_start in range(8, 17):
1708             for y_length in range(8):
1709                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1710
1711     def test_cache(self):
1712         c = ResponseCache()
1713         # xdata = base62.b2a(os.urandom(100))[:100]
1714         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1715         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1716         nope = (None, None)
1717         c.add("v1", 1, 0, xdata, "time0")
1718         c.add("v1", 1, 2000, ydata, "time1")
1719         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1720         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1721         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1722         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1723         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1724         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1725         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1726         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1727         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1728         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1729         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1730         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1731         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1732         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1733         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1734         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1735         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1736         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
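
        # note the pattern above: read() only returns data when the requested
        # range falls entirely inside a single cached span; anything that
        # spans a gap or extends past an edge yields (None, None)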
1737
1738         # optional: join fragments
1739         c = ResponseCache()
1740         c.add("v1", 1, 0, xdata[:10], "time0")
1741         c.add("v1", 1, 10, xdata[10:20], "time1")
1742         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1743
1744 class Exceptions(unittest.TestCase):
1745     def test_repr(self):
1746         nmde = NeedMoreDataError(100, 50, 100)
1747         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1748         ucwe = UncoordinatedWriteError()
1749         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1750
1751 class SameKeyGenerator:
1752     def __init__(self, pubkey, privkey):
1753         self.pubkey = pubkey
1754         self.privkey = privkey
1755     def generate(self, keysize=None):
1756         return defer.succeed( (self.pubkey, self.privkey) )
1757
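# The two wrappers below plug into the no_network grid's post-call notifier:
# each wrapped storage server calls .notify(retval, wrapper, methname) after
# every remote method, letting a test sabotage exactly one server. A typical
# hookup (as done in the Problems tests further down) looks like:
#
#   killer = FirstServerGetsKilled()
#   for (serverid, ss) in nm.storage_broker.get_all_servers():
#       ss.post_call_notifier = killer.notify
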
1758 class FirstServerGetsKilled:
1759     done = False
1760     def notify(self, retval, wrapper, methname):
1761         if not self.done:
1762             wrapper.broken = True
1763             self.done = True
1764         return retval
1765
1766 class FirstServerGetsDeleted:
1767     def __init__(self):
1768         self.done = False
1769         self.silenced = None
1770     def notify(self, retval, wrapper, methname):
1771         if not self.done:
1772             # this query will work, but later queries should think the share
1773             # has been deleted
1774             self.done = True
1775             self.silenced = wrapper
1776             return retval
1777         if wrapper == self.silenced:
1778             assert methname == "slot_testv_and_readv_and_writev"
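            # pretend this server's shares are gone: report success with an
            # empty read vector, as if nothing were stored here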
1779             return (True, {})
1780         return retval
1781
1782 class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
1783     def test_publish_surprise(self):
1784         self.basedir = "mutable/Problems/test_publish_surprise"
1785         self.set_up_grid()
1786         nm = self.g.clients[0].nodemaker
1787         d = nm.create_mutable_file("contents 1")
1788         def _created(n):
1789             d = defer.succeed(None)
1790             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1791             def _got_smap1(smap):
1792                 # stash the old state of the file
1793                 self.old_map = smap
1794             d.addCallback(_got_smap1)
1795             # then modify the file, leaving the old map untouched
1796             d.addCallback(lambda res: log.msg("starting winning write"))
1797             d.addCallback(lambda res: n.overwrite("contents 2"))
1798             # now attempt to modify the file with the old servermap. This
1799             # will look just like an uncoordinated write, in which every
1800             # single share got updated between our mapupdate and our publish
1801             d.addCallback(lambda res: log.msg("starting doomed write"))
1802             d.addCallback(lambda res:
1803                           self.shouldFail(UncoordinatedWriteError,
1804                                           "test_publish_surprise", None,
1805                                           n.upload,
1806                                           "contents 2a", self.old_map))
1807             return d
1808         d.addCallback(_created)
1809         return d
1810
1811     def test_retrieve_surprise(self):
1812         self.basedir = "mutable/Problems/test_retrieve_surprise"
1813         self.set_up_grid()
1814         nm = self.g.clients[0].nodemaker
1815         d = nm.create_mutable_file("contents 1")
1816         def _created(n):
1817             d = defer.succeed(None)
1818             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1819             def _got_smap1(smap):
1820                 # stash the old state of the file
1821                 self.old_map = smap
1822             d.addCallback(_got_smap1)
1823             # then modify the file, leaving the old map untouched
1824             d.addCallback(lambda res: log.msg("starting winning write"))
1825             d.addCallback(lambda res: n.overwrite("contents 2"))
1826             # now attempt to retrieve the old version with the old servermap.
1827             # This will look like someone has changed the file since we
1828             # updated the servermap.
1829             d.addCallback(lambda res: n._cache._clear())
1830             d.addCallback(lambda res: log.msg("starting doomed read"))
1831             d.addCallback(lambda res:
1832                           self.shouldFail(NotEnoughSharesError,
1833                                           "test_retrieve_surprise",
1834                                           "ran out of peers: have 0 shares (k=3)",
1835                                           n.download_version,
1836                                           self.old_map,
1837                                           self.old_map.best_recoverable_version(),
1838                                           ))
1839             return d
1840         d.addCallback(_created)
1841         return d
1842
1843     def test_unexpected_shares(self):
1844         # upload the file, take a servermap, shut down one of the servers,
1845         # upload it again (causing shares to appear on a new server), then
1846         # upload using the old servermap. The last upload should fail with an
1847         # UncoordinatedWriteError, because of the shares that didn't appear
1848         # in the servermap.
1849         self.basedir = "mutable/Problems/test_unexpected_shares"
1850         self.set_up_grid()
1851         nm = self.g.clients[0].nodemaker
1852         d = nm.create_mutable_file("contents 1")
1853         def _created(n):
1854             d = defer.succeed(None)
1855             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1856             def _got_smap1(smap):
1857                 # stash the old state of the file
1858                 self.old_map = smap
1859                 # now shut down one of the servers
1860                 peer0 = list(smap.make_sharemap()[0])[0]
1861                 self.g.remove_server(peer0)
1862                 # then modify the file, leaving the old map untouched
1863                 log.msg("starting winning write")
1864                 return n.overwrite("contents 2")
1865             d.addCallback(_got_smap1)
1866             # now attempt to modify the file with the old servermap. This
1867             # will look just like an uncoordinated write, in which every
1868             # single share got updated between our mapupdate and our publish
1869             d.addCallback(lambda res: log.msg("starting doomed write"))
1870             d.addCallback(lambda res:
1871                           self.shouldFail(UncoordinatedWriteError,
1872                                           "test_unexpected_shares", None,
1873                                           n.upload,
1874                                           "contents 2a", self.old_map))
1875             return d
1876         d.addCallback(_created)
1877         return d
1878
1879     def test_bad_server(self):
1880         # Break one server, then create the file: the initial publish should
1881         # complete with an alternate server. Breaking a second server should
1882         # not prevent an update from succeeding either.
1883         self.basedir = "mutable/Problems/test_bad_server"
1884         self.set_up_grid()
1885         nm = self.g.clients[0].nodemaker
1886
1887         # to make sure that one of the initial peers is broken, we have to
1888         # get creative. We create an RSA key and compute its storage-index.
1889         # Then we make a KeyGenerator that always returns that one key, and
1890         # use it to create the mutable file. This will get easier when we can
1891         # use #467 static-server-selection to disable permutation and force
1892         # the choice of server for share[0].
1893
1894         d = nm.key_generator.generate(522)
1895         def _got_key( (pubkey, privkey) ):
1896             nm.key_generator = SameKeyGenerator(pubkey, privkey)
1897             pubkey_s = pubkey.serialize()
1898             privkey_s = privkey.serialize()
1899             u = uri.WriteableSSKFileURI(ssk_writekey_hash(privkey_s),
1900                                         ssk_pubkey_fingerprint_hash(pubkey_s))
1901             self._storage_index = u.storage_index
1902         d.addCallback(_got_key)
1903         def _break_peer0(res):
1904             si = self._storage_index
1905             peerlist = nm.storage_broker.get_servers_for_index(si)
1906             peerid0, connection0 = peerlist[0]
1907             peerid1, connection1 = peerlist[1]
1908             connection0.broken = True
1909             self.connection1 = connection1
1910         d.addCallback(_break_peer0)
1911         # now "create" the file, using the pre-established key, and let the
1912         # initial publish finally happen
1913         d.addCallback(lambda res: nm.create_mutable_file("contents 1"))
1914         # that ought to work
1915         def _got_node(n):
1916             d = n.download_best_version()
1917             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1918             # now break the second peer
1919             def _break_peer1(res):
1920                 self.connection1.broken = True
1921             d.addCallback(_break_peer1)
1922             d.addCallback(lambda res: n.overwrite("contents 2"))
1923             # that ought to work too
1924             d.addCallback(lambda res: n.download_best_version())
1925             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1926             def _explain_error(f):
1927                 print f
1928                 if f.check(NotEnoughServersError):
1929                     print "first_error:", f.value.first_error
1930                 return f
1931             d.addErrback(_explain_error)
1932             return d
1933         d.addCallback(_got_node)
1934         return d
1935
1936     def test_bad_server_overlap(self):
1937         # like test_bad_server, but with no extra unused servers to fall back
1938         # upon. This means that we must re-use a server which we've already
1939         # used. If we don't remember the fact that we sent them one share
1940         # already, we'll mistakenly think we're experiencing an
1941         # UncoordinatedWriteError.
1942
1943         # Break one server, then create the file: the initial publish should
1944         # complete with an alternate server. Breaking a second server should
1945         # not prevent an update from succeeding either.
1946         self.basedir = "mutable/Problems/test_bad_server_overlap"
1947         self.set_up_grid()
1948         nm = self.g.clients[0].nodemaker
1949         sb = nm.storage_broker
1950
1951         peerids = [serverid for (serverid,ss) in sb.get_all_servers()]
1952         self.g.break_server(peerids[0])
1953
1954         d = nm.create_mutable_file("contents 1")
1955         def _created(n):
1956             d = n.download_best_version()
1957             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1958             # now break one of the remaining servers
1959             def _break_second_server(res):
1960                 self.g.break_server(peerids[1])
1961             d.addCallback(_break_second_server)
1962             d.addCallback(lambda res: n.overwrite("contents 2"))
1963             # that ought to work too
1964             d.addCallback(lambda res: n.download_best_version())
1965             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1966             return d
1967         d.addCallback(_created)
1968         return d
1969
1970     def test_publish_all_servers_bad(self):
1971         # Break all servers: the publish should fail
1972         self.basedir = "mutable/Problems/test_publish_all_servers_bad"
1973         self.set_up_grid()
1974         nm = self.g.clients[0].nodemaker
1975         for (serverid,ss) in nm.storage_broker.get_all_servers():
1976             ss.broken = True
1977
1978         d = self.shouldFail(NotEnoughServersError,
1979                             "test_publish_all_servers_bad",
1980                             "Ran out of non-bad servers",
1981                             nm.create_mutable_file, "contents")
1982         return d
1983
1984     def test_publish_no_servers(self):
1985         # no servers at all: the publish should fail
1986         self.basedir = "mutable/Problems/test_publish_no_servers"
1987         self.set_up_grid(num_servers=0)
1988         nm = self.g.clients[0].nodemaker
1989
1990         d = self.shouldFail(NotEnoughServersError,
1991                             "test_publish_no_servers",
1992                             "Ran out of non-bad servers",
1993                             nm.create_mutable_file, "contents")
1994         return d
1995     test_publish_no_servers.timeout = 30
1996
1997
1998     def test_privkey_query_error(self):
1999         # when a servermap is updated with MODE_WRITE, it tries to get the
2000         # privkey. Something might go wrong during this query attempt.
2001         # Exercise the code in _privkey_query_failed which tries to handle
2002         # such an error.
2003         self.basedir = "mutable/Problems/test_privkey_query_error"
2004         self.set_up_grid(num_servers=20)
2005         nm = self.g.clients[0].nodemaker
2006         nm._node_cache = DevNullDictionary() # disable the nodecache
2007
2008         # we need some contents that are large enough to push the privkey out
2009         # of the early part of the file
2010         LARGE = "These are Larger contents" * 2000 # about 50KB
2011         d = nm.create_mutable_file(LARGE)
2012         def _created(n):
2013             self.uri = n.get_uri()
2014             self.n2 = nm.create_from_cap(self.uri)
2015
2016             # When a mapupdate is performed on a node that doesn't yet know
2017             # the privkey, a short read is sent to a batch of servers, to get
2018             # the verinfo and (hopefully, if the file is short enough) the
2019             # encprivkey. Our file is too large to let this first read
2020             # contain the encprivkey. Each non-encprivkey-bearing response
2021             # that arrives (until the node gets the encprivkey) will trigger
2022             # a second read to specifically read the encprivkey.
2023             #
2024             # So, to exercise this case:
2025             #  1. notice which server gets a read() call first
2026             #  2. tell that server to start throwing errors
2027             killer = FirstServerGetsKilled()
2028             for (serverid,ss) in nm.storage_broker.get_all_servers():
2029                 ss.post_call_notifier = killer.notify
2030         d.addCallback(_created)
2031
2032         # now we update a servermap from a new node (which doesn't have the
2033         # privkey yet, forcing it to use a separate privkey query). Note that
2034         # the map-update will succeed, since we'll just get a copy from one
2035         # of the other shares.
2036         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2037
2038         return d
2039
2040     def test_privkey_query_missing(self):
2041         # like test_privkey_query_error, but here the first server's shares
2042         # appear deleted when the second query arrives, instead of an
2043         # exception being raised.
2043         self.basedir = "mutable/Problems/test_privkey_query_missing"
2044         self.set_up_grid(num_servers=20)
2045         nm = self.g.clients[0].nodemaker
2046         LARGE = "These are Larger contents" * 2000 # about 50KB
2047         nm._node_cache = DevNullDictionary() # disable the nodecache
2048
2049         d = nm.create_mutable_file(LARGE)
2050         def _created(n):
2051             self.uri = n.get_uri()
2052             self.n2 = nm.create_from_cap(self.uri)
2053             deleter = FirstServerGetsDeleted()
2054             for (serverid,ss) in nm.storage_broker.get_all_servers():
2055                 ss.post_call_notifier = deleter.notify
2056         d.addCallback(_created)
2057         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2058         return d