git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_mutable.py
mutable repairer: skip repair of readcaps instead of throwing an exception.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_mutable.py
1
2 import os, struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from twisted.python import failure
7 from allmydata import uri
8 from allmydata.storage.server import StorageServer
9 from allmydata.immutable import download
10 from allmydata.util import base32, idlib
11 from allmydata.util.idlib import shortnodeid_b2a
12 from allmydata.util.hashutil import tagged_hash
13 from allmydata.util.fileutil import make_dirs
14 from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
15      NotEnoughSharesError, IRepairResults, ICheckAndRepairResults
16 from allmydata.monitor import Monitor
17 from allmydata.test.common import ShouldFailMixin
18 from foolscap.api import eventually, fireEventually
19 from foolscap.logging import log
20 from allmydata.storage_client import StorageFarmBroker
21
22 from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
23 from allmydata.mutable.common import ResponseCache, \
24      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
25      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
26      NotEnoughServersError, CorruptShareError
27 from allmydata.mutable.retrieve import Retrieve
28 from allmydata.mutable.publish import Publish
29 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
30 from allmydata.mutable.layout import unpack_header, unpack_share
31 from allmydata.mutable.repairer import MustForceRepairError
32
33 import common_util as testutil
34
35 # this "FastMutableFileNode" exists solely to speed up tests by using smaller
36 # public/private keys. Once we switch to fast DSA-based keys, we can get rid
37 # of this.
38
class FastMutableFileNode(MutableFileNode):
    """MutableFileNode variant used only by the tests.

    It overrides SIGNATURE_KEY_SIZE with a much smaller value so that
    creating a mutable file (and its keypair) is quicker in tests.
    """
    SIGNATURE_KEY_SIZE = 522
41
42 # this "FakeStorage" exists to put the share data in RAM and avoid using real
43 # network connections, both to speed up the tests and to reduce the amount of
44 # non-mutable.py code being exercised.
45
class FakeStorage:
    """In-RAM replacement for the whole collection of storage servers.

    Shares are kept in ``self._peers`` as ``{peerid: {shnum: sharedata}}``,
    which lets tests examine and manipulate the published shares directly.
    It also lets tests control the order in which read queries are
    answered, to exercise more of the error-handling code in Retrieve.

    Note that we ignore the storage index: this FakeStorage instance can
    only be used for a single storage index.
    """

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating them in _pending, which
        # maps peerid to a list of (Deferred, shares) pairs. We don't know
        # exactly how many queries we'll get, so exactly one second after
        # the first query arrives, we will release them all (in order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        # peerid -> list of answer modes ("fail", "none", "normal"); one
        # mode is consumed by each read() against that peer.
        self._special_answers = {}

    def read(self, peerid, storage_index):
        """Return a Deferred that fires with peerid's {shnum: data} dict.

        A queued special answer for this peer can override the result:
        "fail" makes the Deferred fail with IntentionalError, "none"
        answers with an empty dict, "normal" answers as usual. When
        _sequence is set, the answer is held until _fire_readers runs.
        """
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        # A peer may be queried more than once before the timer fires.
        # Keep every pending answer: overwriting an earlier entry would
        # leave its Deferred unfired forever.
        self._pending.setdefault(peerid, []).append((d, shares))
        return d

    def _fire_readers(self):
        """Release all held reads.

        Peers named in _sequence are answered first, in that order. Any
        remaining peers are answered afterwards.
        """
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                for (d, shares) in pending.pop(peerid):
                    eventually(d.callback, shares)
        for answers in pending.values():
            for (d, shares) in answers:
                eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        """Write data into share shnum on peerid, starting at byte offset.

        The peer's share dict and the share itself are created on demand.
        """
        shares = self._peers.setdefault(peerid, {})
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
108
109
class FakeStorageServer:
    """Pretend remote storage server backed by a shared FakeStorage.

    callRemote dispatches to the same-named local method on a later
    eventual-send turn, so callers see an asynchronous Deferred result the
    way they would with a real remote reference.
    """
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0

    def callRemote(self, methname, *args, **kwargs):
        """Call self.<methname>(*args, **kwargs) on a later turn.

        Returns a Deferred with that method's result.
        """
        def _dispatch(ignored):
            return getattr(self, methname)(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_dispatch)
        return d

    def callRemoteOnly(self, methname, *args, **kwargs):
        """Fire-and-forget callRemote.

        The outcome (result or failure) is discarded and nothing is
        returned.
        """
        self.callRemote(methname, *args, **kwargs).addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        # corruption reports are accepted and ignored
        return None

    def slot_readv(self, storage_index, shnums, readv):
        """Read (offset, length) vectors from this peer's shares.

        An empty/false shnums means every share. Returns a Deferred firing
        with {shnum: [data, ...]}, one entry per readv element.
        """
        def _extract(all_shares):
            result = {}
            for sh in all_shares:
                if shnums and sh not in shnums:
                    continue
                pieces = result[sh] = []
                data = all_shares[sh]
                for (start, size) in readv:
                    assert isinstance(start, (int, long)), start
                    assert isinstance(size, (int, long)), size
                    pieces.append(data[start:start+size])
            return result
        d = self.storage.read(self.peerid, storage_index)
        d.addCallback(_extract)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        """Apply every write vector unconditionally.

        Tests always pass: each test vector's specimen is echoed back as
        the "read" data for its share. Returns a Deferred firing with
        (True, {shnum: [specimen, ...]}).
        """
        echoed = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (_offset, _length, op, _specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            echoed[shnum] = [t[3] for t in testv]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        return fireEventually((True, echoed))
163
164
165 # our "FakeClient" has just enough functionality of the real Client to let
166 # the tests run.
167
class FakeClient:
    """Just enough of the real Client for the mutable-file tests.

    It owns one in-RAM FakeStorage and a StorageFarmBroker populated with
    num_peers FakeStorageServers that all share that storage.
    """
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self.nodeid = "fakenodeid"
        self.storage_broker = StorageFarmBroker(None, True)
        for i in range(self._num_peers):
            # deterministic 20-byte peerids derived from the index
            peerid = tagged_hash("peerid", "%d" % i)[:20]
            server = FakeStorageServer(peerid, self._storage)
            self.storage_broker.test_add_server(peerid, server)

    def get_storage_broker(self):
        return self.storage_broker

    def debug_break_connection(self, peerid):
        servers = self.storage_broker.test_servers
        servers[peerid].broken = True

    def debug_remove_connection(self, peerid):
        servers = self.storage_broker.test_servers
        servers.pop(peerid)

    def debug_get_connection(self, peerid):
        servers = self.storage_broker.test_servers
        return servers[peerid]

    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"

    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        """Create a new mutable file node holding contents.

        Returns a Deferred firing with the node.
        """
        node = self.mutable_file_node_class(self)
        d = node.create(contents)
        d.addCallback(lambda ignored: node)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        """Build a mutable file node for a mutable-file URI."""
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        return self.mutable_file_node_class(self).init_from_uri(u)

    def upload(self, uploadable):
        """Read all of uploadable and return it as a LiteralFileURI.

        Returns a Deferred firing with that URI (no storage is used).
        """
        assert IUploadable.providedBy(uploadable)
        def _read_all(length):
            return uploadable.read(length)
        def _as_literal_uri(datav):
            return uri.LiteralFileURI("".join(datav))
        d = uploadable.get_size()
        d.addCallback(_read_all)
        d.addCallback(_as_literal_uri)
        return d
228
229
def flip_bit(original, byte_offset):
    """Return a copy of the string original with the lowest bit of the
    byte at byte_offset inverted."""
    flipped = chr(ord(original[byte_offset]) ^ 0x01)
    before = original[:byte_offset]
    after = original[byte_offset+1:]
    return "".join([before, flipped, after])
234
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    """Flip one bit in each selected share held by the FakeStorage s.

    offset says where to corrupt. It can be:
      - a key of the share's header offset table (from unpack_header),
      - the special name "pubkey" (fixed position 107),
      - a plain integer position, or
      - a (name_or_position, extra) tuple, which adds extra bytes.
    offset_offset is added on top. shnums_to_corrupt limits which share
    numbers are touched; None means all of them.

    Shares are replaced in place in s._peers, and res is returned
    unchanged (handy inside a Deferred chain).
    """
    for peerid, shares in s._peers.items():
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (_version, _seqnum, _root_hash, _IV,
             _k, _N, _segsize, _datalen,
             offsets) = unpack_header(data)
            if isinstance(offset, tuple):
                field, extra = offset
            else:
                field, extra = offset, 0
            if field == "pubkey":
                base = 107
            elif field in offsets:
                base = offsets[field]
            else:
                base = field
            target = int(base) + extra + offset_offset
            assert isinstance(target, int), offset
            shares[shnum] = flip_bit(data, target)
    return res
266
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    """MutableFileNode create/serialize/upload/download/modify tests run
    against the in-RAM FakeClient."""
    # this used to be in Publish, but we removed the limit. Some of
    # these tests test whether the new code correctly allows files
    # larger than the limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.client.get_storage_broker()
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def _check_upload_and_download(self, n):
        """Run node n through the overwrite/upload/download round-trips
        shared by the upload_and_download tests.

        Returns a Deferred that fires after every check has passed. The
        node is left holding "contents 3".
        """
        d = defer.succeed(None)
        d.addCallback(lambda res: n.get_servermap(MODE_READ))
        d.addCallback(lambda smap: smap.dump(StringIO()))
        d.addCallback(lambda sio:
                      self.failUnless("3-of-10" in sio.getvalue()))
        d.addCallback(lambda res: n.overwrite("contents 1"))
        d.addCallback(lambda res: self.failUnlessIdentical(res, None))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        d.addCallback(lambda res: n.get_size_of_best_version())
        d.addCallback(lambda size:
                      self.failUnlessEqual(size, len("contents 1")))
        d.addCallback(lambda res: n.overwrite("contents 2"))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        d.addCallback(lambda res: n.download(download.Data()))
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
        d.addCallback(lambda smap: n.upload("contents 3", smap))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
        d.addCallback(lambda smap:
                      n.download_version(smap,
                                         smap.best_recoverable_version()))
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = self._check_upload_and_download(n)
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.client.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            # The old one-segment limit is gone (see OLD_MAX_SEGMENT_SIZE),
            # so a modifier that grows the file past it must now be
            # published like any other change. Verify that it really was.
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            def _check_big_contents(res):
                expected = "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
                # compare lengths first so a mismatch doesn't dump ~3.5MB
                self.failUnlessEqual(len(res), len(expected))
                self.failUnless(res == expected, "big contents differ")
            d.addCallback(_check_big_contents)
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 5, "big"))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        # same round-trips as test_upload_and_download, but with the
        # production key size instead of FastMutableFileNode's small keys
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        d.addCallback(self._check_upload_and_download)
        return d
556
557
class MakeShares(unittest.TestCase):
    """Drive Publish's encrypt/encode and share-generation steps directly,
    without going through a full publish."""

    def test_encrypt(self):
        contents = "some initial contents"
        fn = FastMutableFileNode(FakeClient())
        d = fn.create(contents)
        def _encode(ignored):
            pub = Publish(fn, None)
            pub.salt = "SALT" * 4
            pub.readkey = "\x00" * 16
            pub.newdata = contents
            pub.required_shares = 3
            pub.total_shares = 10
            pub.setup_encoding_parameters()
            return pub._encrypt_and_encode()
        def _check(shares_and_shareids):
            shares, share_ids = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for share in shares:
                self.failUnless(isinstance(share, str))
                self.failUnlessEqual(len(share), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_encode)
        d.addCallback(_check)
        return d

    def test_generate(self):
        contents = "some initial contents"
        fn = FastMutableFileNode(FakeClient())
        d = fn.create(contents)
        def _generate(ignored):
            pub = self._p = Publish(fn, None)
            pub.newdata = contents
            pub.required_shares = 3
            pub.total_shares = 10
            pub.setup_encoding_parameters()
            pub._new_seqnum = 3
            pub.salt = "SALT" * 4
            # make some fake shares: ten 7-character blocks with ids 0..9
            fake_blocks = ["%07d" % i for i in range(10)]
            pub._privkey = fn.get_privkey()
            pub._encprivkey = fn.get_encprivkey()
            pub._pubkey = fn.get_pubkey()
            return pub._generate_shares((fake_blocks, range(10)))
        def _verify(ignored):
            pub = self._p
            root_hash = pub.root_hash
            final_shares = pub.shares
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for sh in final_shares.values():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = unpack_share(sh)
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(contents))
                self.failUnlessEqual(pubkey, pub._pubkey.serialize())
                signed = struct.pack(">BQ32s16s BBQQ",
                                     0, pub._new_seqnum, root_hash, IV,
                                     k, N, segsize, datalen)
                self.failUnless(pub._pubkey.verify(signed, signature))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum, share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generate)
        d.addCallback(_verify)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError
649
class PublishMixin:
    """Helpers that publish a mutable file into a FakeClient's storage.

    publish_one() creates a single version of one file. publish_multiple()
    creates several versions of one file and snapshots the shares of each
    version into self._copied_shares, so tests can later mix shares from
    different versions with _set_versions().
    """

    def publish_one(self):
        # Publish one file whose shares tests can manipulate afterwards.
        # self._fn2 is a second filenode built from the same URI.
        self.CONTENTS = "New contents go here" * 1000
        self._client = FakeClient(20)
        self._storage = self._client._storage
        def _remember_nodes(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        publishing = self._client.create_mutable_file(self.CONTENTS)
        publishing.addCallback(_remember_nodes)
        return publishing

    def publish_multiple(self):
        # Publish CONTENTS[0] (seqnum=1), then overwrite with CONTENTS[1..3],
        # snapshotting the shares after each step (snapshot index == CONTENTS
        # index). Then roll every share back to snapshot 2 and overwrite with
        # CONTENTS[4] to get a competing "s4b". Storage is left in state 4.
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._client = FakeClient(20)
        self._storage = self._client._storage
        def _build_versions(node):
            self._fn = node
            chain = defer.succeed(None)
            chain.addCallback(self._copy_shares, 0)
            # s2, s3, s4a: overwrite, then snapshot under the same index
            for index in (1, 2, 3):
                chain.addCallback(
                    lambda res, i=index: node.overwrite(self.CONTENTS[i]))
                chain.addCallback(self._copy_shares, index)
            # put every share (shnums 0-9) back at version s3, then publish
            # on top of it to get s4b
            rollback = dict.fromkeys(range(10), 2)
            chain.addCallback(lambda res: self._set_versions(rollback))
            chain.addCallback(lambda res: node.overwrite(self.CONTENTS[4]))
            chain.addCallback(self._copy_shares, 4)
            return chain
        publishing = self._client.create_mutable_file(self.CONTENTS[0])
        publishing.addCallback(_build_versions)
        return publishing

    def _copy_shares(self, ignored, index):
        # Snapshot storage's {peerid: {shnum: data}} table as
        # self._copied_shares[index]. Both levels of dict are copied, so
        # later changes to storage do not show up in the snapshot. The share
        # data objects themselves are shared, not copied (presumably
        # immutable strings).
        snapshot = {}
        for peerid, peer_shares in self._client._storage._peers.items():
            snapshot[peerid] = dict(peer_shares)
        self._copied_shares[index] = snapshot

    def _set_versions(self, versionmap):
        # versionmap maps shnum -> snapshot index (0..4). Every share whose
        # shnum appears in the map is replaced by its copy from that
        # snapshot. Shnums not in the map keep their current version.
        snapshots = self._copied_shares
        for peerid, peer_shares in self._client._storage._peers.items():
            for shnum in peer_shares:
                if shnum not in versionmap:
                    continue
                wanted = versionmap[shnum]
                peer_shares[shnum] = snapshots[wanted][peerid][shnum]
717                     index = versionmap[shnum]
718                     shares[peerid][shnum] = oldshares[index][peerid][shnum]
719
720
class Servermap(unittest.TestCase, PublishMixin):
    """Tests for ServermapUpdater.

    Covers which versions and how many shares the updater finds in each
    mapupdate mode, reuse of an existing map, handling of shares marked
    bad, and the results when all or most shares are missing.
    """

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        # Build a fresh ServerMap for fn (default: the published filenode).
        # Returns the Deferred from ServermapUpdater.update().
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        # Run another mapupdate for self._fn, starting from an existing map.
        # Returns the Deferred from ServermapUpdater.update().
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        # Assert that sm shows exactly one (recoverable) version, for which
        # num_shares shares were found with k=3, N=10. Returns sm so the
        # check can be chained as a Deferred callback.
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should no longer contain the marked shares, and new
            # shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        # Assert that sm found no versions at all, recoverable or not.
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        # Assert that sm shows exactly one version, which is unrecoverable
        # because only 2 of the required 3 shares were found. Returns sm.
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        # Empty out peers one by one until only two remain untouched. Note
        # that this counts peers, not shares.
        peers_left = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            peers_left -= 1
            if peers_left == 2:
                break
        # Now only two peers should still hold shares. Use a real test
        # assertion: a bare 'assert' is stripped under 'python -O', which
        # would silently skip this precondition check.
        self.failUnlessEqual(len([peerid for peerid in s._peers
                                  if s._peers[peerid]]), 2)

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d
894         d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
895
896         return d
897
898
899
900 class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
901     def setUp(self):
902         return self.publish_one()
903
904     def make_servermap(self, mode=MODE_READ, oldmap=None):
905         if oldmap is None:
906             oldmap = ServerMap()
907         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
908         d = smu.update()
909         return d
910
911     def abbrev_verinfo(self, verinfo):
912         if verinfo is None:
913             return None
914         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
915          offsets_tuple) = verinfo
916         return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])
917
918     def abbrev_verinfo_dict(self, verinfo_d):
919         output = {}
920         for verinfo,value in verinfo_d.items():
921             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
922              offsets_tuple) = verinfo
923             output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
924         return output
925
926     def dump_servermap(self, servermap):
927         print "SERVERMAP", servermap
928         print "RECOVERABLE", [self.abbrev_verinfo(v)
929                               for v in servermap.recoverable_versions()]
930         print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
931         print "available", self.abbrev_verinfo_dict(servermap.shares_available())
932
933     def do_download(self, servermap, version=None):
934         if version is None:
935             version = servermap.best_recoverable_version()
936         r = Retrieve(self._fn, servermap, version)
937         return r.download()
938
939     def test_basic(self):
940         d = self.make_servermap()
941         def _do_retrieve(servermap):
942             self._smap = servermap
943             #self.dump_servermap(servermap)
944             self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
945             return self.do_download(servermap)
946         d.addCallback(_do_retrieve)
947         def _retrieved(new_contents):
948             self.failUnlessEqual(new_contents, self.CONTENTS)
949         d.addCallback(_retrieved)
950         # we should be able to re-use the same servermap, both with and
951         # without updating it.
952         d.addCallback(lambda res: self.do_download(self._smap))
953         d.addCallback(_retrieved)
954         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
955         d.addCallback(lambda res: self.do_download(self._smap))
956         d.addCallback(_retrieved)
957         # clobbering the pubkey should make the servermap updater re-fetch it
958         def _clobber_pubkey(res):
959             self._fn._pubkey = None
960         d.addCallback(_clobber_pubkey)
961         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
962         d.addCallback(lambda res: self.do_download(self._smap))
963         d.addCallback(_retrieved)
964         return d
965
966     def test_all_shares_vanished(self):
967         d = self.make_servermap()
968         def _remove_shares(servermap):
969             for shares in self._storage._peers.values():
970                 shares.clear()
971             d1 = self.shouldFail(NotEnoughSharesError,
972                                  "test_all_shares_vanished",
973                                  "ran out of peers",
974                                  self.do_download, servermap)
975             return d1
976         d.addCallback(_remove_shares)
977         return d
978
979     def test_no_servers(self):
980         c2 = FakeClient(0)
981         self._fn._client = c2
982         # if there are no servers, then a MODE_READ servermap should come
983         # back empty
984         d = self.make_servermap()
985         def _check_servermap(servermap):
986             self.failUnlessEqual(servermap.best_recoverable_version(), None)
987             self.failIf(servermap.recoverable_versions())
988             self.failIf(servermap.unrecoverable_versions())
989             self.failIf(servermap.all_peers())
990         d.addCallback(_check_servermap)
991         return d
992     test_no_servers.timeout = 15
993
994     def test_no_servers_download(self):
995         c2 = FakeClient(0)
996         self._fn._client = c2
997         d = self.shouldFail(UnrecoverableFileError,
998                             "test_no_servers_download",
999                             "no recoverable versions",
1000                             self._fn.download_best_version)
1001         def _restore(res):
1002             # a failed download that occurs while we aren't connected to
1003             # anybody should not prevent a subsequent download from working.
1004             # This isn't quite the webapi-driven test that #463 wants, but it
1005             # should be close enough.
1006             self._fn._client = self._client
1007             return self._fn.download_best_version()
1008         def _retrieved(new_contents):
1009             self.failUnlessEqual(new_contents, self.CONTENTS)
1010         d.addCallback(_restore)
1011         d.addCallback(_retrieved)
1012         return d
1013     test_no_servers_download.timeout = 15
1014
1015     def _test_corrupt_all(self, offset, substring,
1016                           should_succeed=False, corrupt_early=True,
1017                           failure_checker=None):
1018         d = defer.succeed(None)
1019         if corrupt_early:
1020             d.addCallback(corrupt, self._storage, offset)
1021         d.addCallback(lambda res: self.make_servermap())
1022         if not corrupt_early:
1023             d.addCallback(corrupt, self._storage, offset)
1024         def _do_retrieve(servermap):
1025             ver = servermap.best_recoverable_version()
1026             if ver is None and not should_succeed:
1027                 # no recoverable versions == not succeeding. The problem
1028                 # should be noted in the servermap's list of problems.
1029                 if substring:
1030                     allproblems = [str(f) for f in servermap.problems]
1031                     self.failUnless(substring in "".join(allproblems))
1032                 return servermap
1033             if should_succeed:
1034                 d1 = self._fn.download_version(servermap, ver)
1035                 d1.addCallback(lambda new_contents:
1036                                self.failUnlessEqual(new_contents, self.CONTENTS))
1037             else:
1038                 d1 = self.shouldFail(NotEnoughSharesError,
1039                                      "_corrupt_all(offset=%s)" % (offset,),
1040                                      substring,
1041                                      self._fn.download_version, servermap, ver)
1042             if failure_checker:
1043                 d1.addCallback(failure_checker)
1044             d1.addCallback(lambda res: servermap)
1045             return d1
1046         d.addCallback(_do_retrieve)
1047         return d
1048
1049     def test_corrupt_all_verbyte(self):
1050         # when the version byte is not 0, we hit an assertion error in
1051         # unpack_share().
1052         d = self._test_corrupt_all(0, "AssertionError")
1053         def _check_servermap(servermap):
1054             # and the dump should mention the problems
1055             s = StringIO()
1056             dump = servermap.dump(s).getvalue()
1057             self.failUnless("10 PROBLEMS" in dump, dump)
1058         d.addCallback(_check_servermap)
1059         return d
1060
1061     def test_corrupt_all_seqnum(self):
1062         # a corrupt sequence number will trigger a bad signature
1063         return self._test_corrupt_all(1, "signature is invalid")
1064
1065     def test_corrupt_all_R(self):
1066         # a corrupt root hash will trigger a bad signature
1067         return self._test_corrupt_all(9, "signature is invalid")
1068
1069     def test_corrupt_all_IV(self):
1070         # a corrupt salt/IV will trigger a bad signature
1071         return self._test_corrupt_all(41, "signature is invalid")
1072
1073     def test_corrupt_all_k(self):
1074         # a corrupt 'k' will trigger a bad signature
1075         return self._test_corrupt_all(57, "signature is invalid")
1076
1077     def test_corrupt_all_N(self):
1078         # a corrupt 'N' will trigger a bad signature
1079         return self._test_corrupt_all(58, "signature is invalid")
1080
1081     def test_corrupt_all_segsize(self):
1082         # a corrupt segsize will trigger a bad signature
1083         return self._test_corrupt_all(59, "signature is invalid")
1084
1085     def test_corrupt_all_datalen(self):
1086         # a corrupt data length will trigger a bad signature
1087         return self._test_corrupt_all(67, "signature is invalid")
1088
1089     def test_corrupt_all_pubkey(self):
1090         # a corrupt pubkey won't match the URI's fingerprint. We need to
1091         # remove the pubkey from the filenode, or else it won't bother trying
1092         # to update it.
1093         self._fn._pubkey = None
1094         return self._test_corrupt_all("pubkey",
1095                                       "pubkey doesn't match fingerprint")
1096
1097     def test_corrupt_all_sig(self):
1098         # a corrupt signature is a bad one
1099         # the signature runs from about [543:799], depending upon the length
1100         # of the pubkey
1101         return self._test_corrupt_all("signature", "signature is invalid")
1102
1103     def test_corrupt_all_share_hash_chain_number(self):
1104         # a corrupt share hash chain entry will show up as a bad hash. If we
1105         # mangle the first byte, that will look like a bad hash number,
1106         # causing an IndexError
1107         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1108
1109     def test_corrupt_all_share_hash_chain_hash(self):
1110         # a corrupt share hash chain entry will show up as a bad hash. If we
1111         # mangle a few bytes in, that will look like a bad hash.
1112         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1113
1114     def test_corrupt_all_block_hash_tree(self):
1115         return self._test_corrupt_all("block_hash_tree",
1116                                       "block hash tree failure")
1117
1118     def test_corrupt_all_block(self):
1119         return self._test_corrupt_all("share_data", "block hash tree failure")
1120
1121     def test_corrupt_all_encprivkey(self):
1122         # a corrupted privkey won't even be noticed by the reader, only by a
1123         # writer.
1124         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1125
1126
1127     def test_corrupt_all_seqnum_late(self):
1128         # corrupting the seqnum between mapupdate and retrieve should result
1129         # in NotEnoughSharesError, since each share will look invalid
1130         def _check(res):
1131             f = res[0]
1132             self.failUnless(f.check(NotEnoughSharesError))
1133             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1134         return self._test_corrupt_all(1, "ran out of peers",
1135                                       corrupt_early=False,
1136                                       failure_checker=_check)
1137
1138     def test_corrupt_all_block_hash_tree_late(self):
1139         def _check(res):
1140             f = res[0]
1141             self.failUnless(f.check(NotEnoughSharesError))
1142         return self._test_corrupt_all("block_hash_tree",
1143                                       "block hash tree failure",
1144                                       corrupt_early=False,
1145                                       failure_checker=_check)
1146
1147
1148     def test_corrupt_all_block_late(self):
1149         def _check(res):
1150             f = res[0]
1151             self.failUnless(f.check(NotEnoughSharesError))
1152         return self._test_corrupt_all("share_data", "block hash tree failure",
1153                                       corrupt_early=False,
1154                                       failure_checker=_check)
1155
1156
1157     def test_basic_pubkey_at_end(self):
1158         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1159         # download to succeed but forcing a bunch of retries first. Note that
1160         # this is rather pessimistic: our Retrieve process will throw away
1161         # the whole share if the pubkey is bad, even though the rest of the
1162         # share might be good.
1163
1164         self._fn._pubkey = None
1165         k = self._fn.get_required_shares()
1166         N = self._fn.get_total_shares()
1167         d = defer.succeed(None)
1168         d.addCallback(corrupt, self._storage, "pubkey",
1169                       shnums_to_corrupt=range(0, N-k))
1170         d.addCallback(lambda res: self.make_servermap())
1171         def _do_retrieve(servermap):
1172             self.failUnless(servermap.problems)
1173             self.failUnless("pubkey doesn't match fingerprint"
1174                             in str(servermap.problems[0]))
1175             ver = servermap.best_recoverable_version()
1176             r = Retrieve(self._fn, servermap, ver)
1177             return r.download()
1178         d.addCallback(_do_retrieve)
1179         d.addCallback(lambda new_contents:
1180                       self.failUnlessEqual(new_contents, self.CONTENTS))
1181         return d
1182
1183     def test_corrupt_some(self):
1184         # corrupt the data of first five shares (so the servermap thinks
1185         # they're good but retrieve marks them as bad), so that the
1186         # MODE_READ set of 6 will be insufficient, forcing node.download to
1187         # retry with more servers.
1188         corrupt(None, self._storage, "share_data", range(5))
1189         d = self.make_servermap()
1190         def _do_retrieve(servermap):
1191             ver = servermap.best_recoverable_version()
1192             self.failUnless(ver)
1193             return self._fn.download_best_version()
1194         d.addCallback(_do_retrieve)
1195         d.addCallback(lambda new_contents:
1196                       self.failUnlessEqual(new_contents, self.CONTENTS))
1197         return d
1198
1199     def test_download_fails(self):
1200         corrupt(None, self._storage, "signature")
1201         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1202                             "no recoverable versions",
1203                             self._fn.download_best_version)
1204         return d
1205
1206
class CheckerMixin:
    """Assertions on check/verify results, usable as Deferred callbacks.

    check_good() and check_bad() return the results object unchanged, so
    further callbacks can keep inspecting it.
    """

    def check_good(self, r, where):
        # results must report healthy; 'where' labels the failure message
        self.failUnless(r.is_healthy(), where)
        return r

    def check_bad(self, r, where):
        # results must report unhealthy; 'where' labels the failure message
        self.failIf(r.is_healthy(), where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        # Find the first entry in r.problems whose failure matches
        # expected_exception, and require 'substring' in its text. Fail if
        # no problem matches. Only the first match is checked.
        for (_peerid, _storage_index, _shnum, problem) in r.problems:
            if not problem.check(expected_exception):
                continue
            text = str(problem)
            self.failUnless(substring in text,
                            "%s: substring '%s' not in '%s'" %
                            (where, substring, text))
            return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))
1222                 return
1223         self.fail("%s: didn't see expected exception %s in problems %s" %
1224                   (where, expected_exception, r.problems))
1225
1226
class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    """Checker (plain check) and verifier (verify=True) behavior.

    Each test runs against a freshly published file whose shares are left
    intact, removed, or deliberately corrupted first.
    """

    def setUp(self):
        return self.publish_one()

    def _run_check(self, **kwargs):
        # Start a check of the published filenode with a new Monitor; any
        # keyword arguments (e.g. verify=True) go straight to check().
        return self._fn.check(Monitor(), **kwargs)

    def test_check_good(self):
        checking = self._run_check()
        checking.addCallback(self.check_good, "test_check_good")
        return checking

    def test_check_no_shares(self):
        # wipe every share from every peer
        for peer_shares in self._storage._peers.values():
            peer_shares.clear()
        checking = self._run_check()
        checking.addCallback(self.check_bad, "test_check_no_shares")
        return checking

    def test_check_not_enough_shares(self):
        # drop every share except shnum 0, wherever it is stored
        for peer_shares in self._storage._peers.values():
            for doomed in [n for n in peer_shares if n > 0]:
                del peer_shares[doomed]
        checking = self._run_check()
        checking.addCallback(self.check_bad, "test_check_not_enough_shares")
        return checking

    def test_check_all_bad_sig(self):
        # corrupting byte 1 (the seqnum) of every share invalidates its sig
        corrupt(None, self._storage, 1)
        checking = self._run_check()
        checking.addCallback(self.check_bad, "test_check_all_bad_sig")
        return checking

    def test_check_all_bad_blocks(self):
        # bad block data in share 9: a plain check doesn't look at the
        # actual data, so it won't notice
        corrupt(None, self._storage, "share_data", [9])
        checking = self._run_check()
        checking.addCallback(self.check_good, "test_check_all_bad_blocks")
        return checking

    def test_verify_good(self):
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_good, "test_verify_good")
        return checking

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig in every share
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_bad, "test_verify_all_bad_sig")
        return checking

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig in share 9 only
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_bad, "test_verify_one_bad_sig")
        return checking

    def test_verify_one_bad_block(self):
        # bad blocks in share 9: the verifier *will* notice this, since it
        # examines every byte
        corrupt(None, self._storage, "share_data", [9])
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_bad, "test_verify_one_bad_block")
        checking.addCallback(self.check_expected_failure,
                             CorruptShareError, "block hash tree failure",
                             "test_verify_one_bad_block")
        return checking

    def test_verify_one_bad_sharehash(self):
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
        checking.addCallback(self.check_expected_failure,
                             CorruptShareError, "corrupt hashes",
                             "test_verify_one_bad_sharehash")
        return checking

    def test_verify_one_bad_encprivkey(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        checking = self._run_check(verify=True)
        checking.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
        checking.addCallback(self.check_expected_failure,
                             CorruptShareError, "invalid privkey",
                             "test_verify_one_bad_encprivkey")
        return checking

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        # a read-only node has no way to validate the privkey, so the
        # verifier is expected to report the file as healthy
        readonly_fn = self._fn.get_readonly()
        checking = readonly_fn.check(Monitor(), verify=True)
        checking.addCallback(self.check_good,
                             "test_verify_one_bad_encprivkey_uncheckable")
        return checking
1319
class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
    """Tests for MutableFileNode.repair() and check_and_repair().

    Tests snapshot the grid into self.old_shares (a list of
    {(peerid, shnum): share_data} dicts) so the state before and after an
    operation can be compared.
    """

    def get_shares(self, s):
        """Return {(peerid, shnum): share_data} for every share held by `s`.

        `s` is the fake storage object; its _peers attribute maps
        peerid -> {shnum: share_data}.
        """
        all_shares = {} # maps (peerid, shnum) to share data
        for peerid, shares in s._peers.items():
            for shnum, data in shares.items():
                all_shares[(peerid, shnum)] = data
        return all_shares

    def copy_shares(self, ignored=None):
        """Append a snapshot of the current grid to self.old_shares.

        Accepts (and ignores) one argument so it can be used directly as a
        Deferred callback.
        """
        self.old_shares.append(self.get_shares(self._storage))

    def test_repair_nop(self):
        # repairing a healthy file should leave every share in place, just
        # republished at the next seqnum
        self.old_shares = []
        d = self.publish_one()
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        def _check_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine results

            self.copy_shares()

            initial_shares = self.old_shares[0]
            new_shares = self.old_shares[1]
            # TODO: this really shouldn't change anything. When we implement
            # a "minimal-bandwidth" repairer, change this test to assert:
            #self.failUnlessEqual(new_shares, initial_shares)

            # all shares should be in the same place as before
            self.failUnlessEqual(set(initial_shares.keys()),
                                 set(new_shares.keys()))
            # but they should all be at a newer seqnum. The IV will be
            # different, so the roothash will be too.
            for key in initial_shares:
                (version0,
                 seqnum0,
                 root_hash0,
                 IV0,
                 k0, N0, segsize0, datalen0,
                 o0) = unpack_header(initial_shares[key])
                (version1,
                 seqnum1,
                 root_hash1,
                 IV1,
                 k1, N1, segsize1, datalen1,
                 o1) = unpack_header(new_shares[key])
                self.failUnlessEqual(version0, version1)
                self.failUnlessEqual(seqnum0+1, seqnum1)
                self.failUnlessEqual(k0, k1)
                self.failUnlessEqual(N0, N1)
                self.failUnlessEqual(segsize0, segsize1)
                self.failUnlessEqual(datalen0, datalen1)
        d.addCallback(_check_results)
        return d

    def failIfSharesChanged(self, ignored=None):
        """Fail unless the two most recent snapshots in self.old_shares are
        identical. Usable directly as a Deferred callback."""
        old_shares = self.old_shares[-2]
        current_shares = self.old_shares[-1]
        self.failUnlessEqual(old_shares, current_shares)

    def test_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair will refuse to merge multiple highest seqnums unless you
        # pass force=True
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        def _try_repair(check_results):
            ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
            d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
                                 self._fn.repair, check_results)
            # the refused repair must not have touched any share
            d2.addCallback(self.copy_shares)
            d2.addCallback(self.failIfSharesChanged)
            d2.addCallback(lambda res: check_results)
            return d2
        d.addCallback(_try_repair)
        d.addCallback(lambda check_results:
                      self._fn.repair(check_results, force=True))
        # this should give us 10 shares of the highest roothash
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # now, which should have won? The version with the larger
            # roothash.
            roothash_s4a = self.get_roothash_for(3)
            roothash_s4b = self.get_roothash_for(4)
            if roothash_s4b > roothash_s4a:
                expected_contents = self.CONTENTS[4]
            else:
                expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def test_non_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair should not refuse a repair that doesn't need to merge. In
        # this case, we combine v2 with v3. The repair should ignore v2 and
        # copy v3 into a new v5.
        d.addCallback(lambda res:
                      self._set_versions({0:2,2:2,4:2,6:2,8:2,
                                          1:3,3:3,5:3,7:3,9:3}))
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        # this should give us 10 shares of v3
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # v3 should have won, since nothing needed merging
            expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def get_roothash_for(self, index):
        """Return the root hash of the first share saved for version `index`.

        self._copied_shares is populated outside this class (presumably by
        PublishMixin.publish_multiple) and is indexed as
        [index][peerid][shnum] -> share data. Fails the test if nothing was
        saved for `index`, instead of silently returning None (which would
        make roothash comparisons meaningless).
        """
        shares = self._copied_shares[index]
        for peerid in shares:
            for shnum in shares[peerid]:
                share = shares[peerid][shnum]
                (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
                          unpack_header(share)
                return root_hash
        self.fail("no shares were saved for version index %d" % (index,))

    def test_check_and_repair_readcap(self):
        # we can't currently repair from a mutable readcap: #625
        self.old_shares = []
        d = self.publish_one()
        d.addCallback(self.copy_shares)
        def _get_readcap(res):
            self._fn3 = self._fn.get_readonly()
            # also delete some shares (share 0 from every peer)
            for shares in self._storage._peers.values():
                shares.pop(0, None)
        d.addCallback(_get_readcap)
        d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
        def _check_results(crr):
            self.failUnless(ICheckAndRepairResults.providedBy(crr))
            # we should detect the unhealthy, but skip over mutable-readcap
            # repairs until #625 is fixed
            self.failIf(crr.get_pre_repair_results().is_healthy())
            self.failIf(crr.get_repair_attempted())
            self.failIf(crr.get_post_repair_results().is_healthy())
        d.addCallback(_check_results)
        return d
1489
class MultipleEncodings(unittest.TestCase):
    """Retrieval must survive a grid holding shares of one file that were
    encoded with different (k, N) parameters."""

    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        """Publish `data` as a k-of-n encoding of self._fn.

        Returns a Deferred firing with the resulting
        peerid -> {shnum: share} dict. The client's storage is cleared
        before publishing and cleared again afterwards, so the new shares
        only live in the returned dict.
        """
        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        """Run a ServermapUpdater for self._fn in `mode`, starting from
        `oldmap` (a fresh ServerMap if None). Returns its Deferred."""
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we publish each encoding in turn, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # the location set1 gave it, but using a share from set1, set2
            # or set3 according to the following sequence (a share is only
            # placed if the chosen set also stored that shnum on that peer):
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # so that neither form can be recovered until fetch [g], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will not be recoverable until fetch [h].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
            share_sets = {1: self._shares1, 2: self._shares2, 3: self._shares3}

            sharemap = {}
            storage_peers = self._client._storage._peers
            sb = self._client.get_storage_broker()

            for peerid in sorted(sb.get_all_serverids()):
                shnums = self._shares1.get(peerid, {})
                if not shnums:
                    continue
                # create this peer's share dict once, before the shnum loop:
                # recreating it per shnum would discard every share placed
                # on this peer except the last one
                storage_peers[peerid] = peers = {}
                for shnum in shnums:
                    if shnum >= len(places):
                        # we don't bother placing any other shares
                        continue
                    candidates = share_sets[places[shnum]].get(peerid, {})
                    if shnum in candidates:
                        peers[shnum] = candidates[shnum]
                        sharemap[shnum] = peerid

            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum] for shnum in sorted(sharemap)]
            self._client._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
1632
1633
class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
    """Behaviour of download, check, servermap updates and modify() when the
    grid holds shares from several versions of one mutable file.

    setUp publishes a series of versions via publish_multiple(); the tests
    then call _set_versions() to choose which saved version each share
    number shows. The index given to _set_versions differs from the
    resulting sequence number.
    """

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        def _expect(index):
            # callback asserting the downloaded data equals CONTENTS[index]
            return lambda data: self.failUnlessEqual(data, self.CONTENTS[index])
        def _place(layout):
            # callback rearranging the grid according to `layout`
            return lambda ignored: self._set_versions(layout)
        def _download(ignored):
            return self._fn.download_best_version()
        def _read_servermap(ignored):
            return self._fn.get_servermap(MODE_READ)

        # with a mix of versions on the grid, download_best_version should
        # still fetch the latest one
        self._set_versions(dict.fromkeys((0, 2, 4, 6, 8), 2))
        d = self._fn.download_best_version()
        d.addCallback(_expect(4))
        # ... while the checker reports problems
        d.addCallback(lambda ignored: self._fn.check(Monitor()))
        d.addCallback(self.check_bad, "test_multiple_versions")

        # with everything at version 2, version 2 is what we download
        d.addCallback(_place(dict.fromkeys(range(10), 2)))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # exactly one share at version 3 is not enough: we still get v2
        d.addCallback(_place({0:3}))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # yet the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(_read_servermap)
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer_versions = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer_versions), 1)
            first_verinfo, first_health = newer_versions.items()[0]
            self.failUnlessEqual(first_verinfo[0], 4)
            self.failUnlessEqual(first_health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # two parallel versions (s4a and s4b) split across the shares:
        # either one could be recovered
        d.addCallback(_place({0:3,2:3,4:3,6:3,8:3,
                              1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(_read_servermap)
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer_versions = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer_versions), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(_download)
        d.addCallback(lambda data:
                      self.failUnless(data in (self.CONTENTS[3],
                                               self.CONTENTS[4])))
        return d

    def test_replace(self):
        # given a mix of versions, modify() should replace them all with a
        # newer version. With exactly one share at index 3 (seqnum4) and the
        # rest at index 2 (seqnum3), we should download (and replace) v2,
        # and the result should be v4.
        layout = dict.fromkeys(range(10), 2) # seqnum3
        layout[0] = 3 # seqnum4
        self._set_versions(layout)

        def _modify(oldversion, servermap, first_time):
            return oldversion + " modified"
        d = self._fn.modify(_modify)
        d.addCallback(lambda ignored: self._fn.download_best_version())
        wanted = self.CONTENTS[2] + " modified"
        d.addCallback(lambda data: self.failUnlessEqual(data, wanted))
        # the servermap should show that the outlier was replaced as well
        d.addCallback(lambda ignored: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
        d.addCallback(_check_smap)
        return d
1713
1714
class Utils(unittest.TestCase):
    """Unit tests for ResponseCache's range helpers and add()/read()."""

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # Compare c._inside() with a brute-force answer: range x lies inside
        # range y iff every integer of x is also in y.
        xs = set(range(x_start, x_start + x_length))
        ys = set(range(y_start, y_start + y_length))
        span = (x_start, x_length, y_start, y_length)
        self.failUnlessEqual(xs <= ys, c._inside(*span), str(span))

    def test_cache_inside(self):
        cache = ResponseCache()
        x_start, x_length = 10, 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(cache, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # Compare c._does_overlap() with a brute-force answer: the ranges
        # overlap iff they share at least one integer.
        xs = set(range(x_start, x_start + x_length))
        ys = set(range(y_start, y_start + y_length))
        span = (x_start, x_length, y_start, y_length)
        self.failUnlessEqual(bool(xs & ys), c._does_overlap(*span), str(span))

    def test_cache_overlap(self):
        cache = ResponseCache()
        x_start, x_length = 10, 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(cache, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        # each case is (arguments to c.read, expected (data, timestamp));
        # the 3rd and 4th read arguments act as offset and length
        cases = [
            (("v2", 1, 10, 11), nope),
            (("v1", 2, 10, 11), nope),
            (("v1", 1, 0, 10), (xdata[:10], "time0")),
            (("v1", 1, 90, 10), (xdata[90:], "time0")),
            (("v1", 1, 300, 10), nope),
            (("v1", 1, 2050, 5), (ydata[50:55], "time1")),
            (("v1", 1, 0, 101), nope),
            (("v1", 1, 99, 1), (xdata[99:100], "time0")),
            (("v1", 1, 100, 1), nope),
            ]
        # reads starting at 1990 begin before ydata (added at 2000), so
        # they are never fully covered by the cache
        for length in (9, 10, 11, 15, 19, 20, 21, 25):
            cases.append((("v1", 1, 1990, length), nope))
        cases.append((("v1", 1, 1999, 25), nope))
        for read_args, expected in cases:
            self.failUnlessEqual(c.read(*read_args), expected)

        # optional: join fragments. Adjacent fragments are not required to
        # be joined yet; once they are, this should hold:
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1782
class Exceptions(unittest.TestCase):
    def test_repr(self):
        """The repr of each mutable-file exception mentions its class."""
        def _check(exc, classname):
            text = repr(exc)
            self.failUnless(classname in text, text)
        _check(NeedMoreDataError(100, 50, 100), "NeedMoreDataError")
        _check(UncoordinatedWriteError(), "UncoordinatedWriteError")
1789
1790 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1791 # instances which always succeed. So we need a less-fake one.
1792
class IntentionalError(Exception):
    """Raised by a LocalWrapper whose `broken` flag is set, in place of
    performing the requested remote call."""
1795
class LocalWrapper:
    """Make a local storage server answer callRemote() like a remote one.

    callRemote("foo", ...) invokes original.remote_foo(...) asynchronously
    (via foolscap's fireEventually) and returns a Deferred for the result.
    """

    def __init__(self, original):
        self.original = original
        # when True, every subsequent call fails with IntentionalError
        self.broken = False
        # optional callback(result, methname) chained after each call
        self.post_call_notifier = None

    def callRemote(self, methname, *args, **kwargs):
        def _invoke(ignored):
            # `broken` is consulted when the call actually runs, not when it
            # is queued
            if self.broken:
                raise IntentionalError("I was asked to break")
            target = getattr(self.original, "remote_" + methname)
            return target(*args, **kwargs)
        result = fireEventually()
        result.addCallback(_invoke)
        notifier = self.post_call_notifier
        if notifier:
            result.addCallback(notifier, methname)
        return result
1812
class LessFakeClient(FakeClient):
    """A FakeClient backed by real StorageServer instances on local disk.

    Each server is reached through a LocalWrapper, so individual servers can
    be made to fail. FakeClient.__init__ is deliberately not called.
    """

    def __init__(self, basedir, num_peers=10):
        self._num_peers = num_peers
        self.storage_broker = StorageFarmBroker(None, True)
        for i in range(num_peers):
            # deterministic 20-byte peerid for server number i
            peerid = tagged_hash("peerid", "%d" % i)[:20]
            peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
            make_dirs(peerdir)
            server = StorageServer(peerdir, peerid)
            self.storage_broker.test_add_server(peerid, LocalWrapper(server))
        self.nodeid = "fakenodeid"
1827
1828
1829 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1830     def test_publish_surprise(self):
1831         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1832         self.client = LessFakeClient(basedir)
1833         d = self.client.create_mutable_file("contents 1")
1834         def _created(n):
1835             d = defer.succeed(None)
1836             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1837             def _got_smap1(smap):
1838                 # stash the old state of the file
1839                 self.old_map = smap
1840             d.addCallback(_got_smap1)
1841             # then modify the file, leaving the old map untouched
1842             d.addCallback(lambda res: log.msg("starting winning write"))
1843             d.addCallback(lambda res: n.overwrite("contents 2"))
1844             # now attempt to modify the file with the old servermap. This
1845             # will look just like an uncoordinated write, in which every
1846             # single share got updated between our mapupdate and our publish
1847             d.addCallback(lambda res: log.msg("starting doomed write"))
1848             d.addCallback(lambda res:
1849                           self.shouldFail(UncoordinatedWriteError,
1850                                           "test_publish_surprise", None,
1851                                           n.upload,
1852                                           "contents 2a", self.old_map))
1853             return d
1854         d.addCallback(_created)
1855         return d
1856
1857     def test_retrieve_surprise(self):
1858         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1859         self.client = LessFakeClient(basedir)
1860         d = self.client.create_mutable_file("contents 1")
1861         def _created(n):
1862             d = defer.succeed(None)
1863             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1864             def _got_smap1(smap):
1865                 # stash the old state of the file
1866                 self.old_map = smap
1867             d.addCallback(_got_smap1)
1868             # then modify the file, leaving the old map untouched
1869             d.addCallback(lambda res: log.msg("starting winning write"))
1870             d.addCallback(lambda res: n.overwrite("contents 2"))
1871             # now attempt to retrieve the old version with the old servermap.
1872             # This will look like someone has changed the file since we
1873             # updated the servermap.
1874             d.addCallback(lambda res: n._cache._clear())
1875             d.addCallback(lambda res: log.msg("starting doomed read"))
1876             d.addCallback(lambda res:
1877                           self.shouldFail(NotEnoughSharesError,
1878                                           "test_retrieve_surprise",
1879                                           "ran out of peers: have 0 shares (k=3)",
1880                                           n.download_version,
1881                                           self.old_map,
1882                                           self.old_map.best_recoverable_version(),
1883                                           ))
1884             return d
1885         d.addCallback(_created)
1886         return d
1887
1888     def test_unexpected_shares(self):
1889         # upload the file, take a servermap, shut down one of the servers,
1890         # upload it again (causing shares to appear on a new server), then
1891         # upload using the old servermap. The last upload should fail with an
1892         # UncoordinatedWriteError, because of the shares that didn't appear
1893         # in the servermap.
1894         basedir = os.path.join("mutable/CollidingWrites/test_unexpexted_shares")
1895         self.client = LessFakeClient(basedir)
1896         d = self.client.create_mutable_file("contents 1")
1897         def _created(n):
1898             d = defer.succeed(None)
1899             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1900             def _got_smap1(smap):
1901                 # stash the old state of the file
1902                 self.old_map = smap
1903                 # now shut down one of the servers
1904                 peer0 = list(smap.make_sharemap()[0])[0]
1905                 self.client.debug_remove_connection(peer0)
1906                 # then modify the file, leaving the old map untouched
1907                 log.msg("starting winning write")
1908                 return n.overwrite("contents 2")
1909             d.addCallback(_got_smap1)
1910             # now attempt to modify the file with the old servermap. This
1911             # will look just like an uncoordinated write, in which every
1912             # single share got updated between our mapupdate and our publish
1913             d.addCallback(lambda res: log.msg("starting doomed write"))
1914             d.addCallback(lambda res:
1915                           self.shouldFail(UncoordinatedWriteError,
1916                                           "test_surprise", None,
1917                                           n.upload,
1918                                           "contents 2a", self.old_map))
1919             return d
1920         d.addCallback(_created)
1921         return d
1922
1923     def test_bad_server(self):
1924         # Break one server, then create the file: the initial publish should
1925         # complete with an alternate server. Breaking a second server should
1926         # not prevent an update from succeeding either.
1927         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1928         self.client = LessFakeClient(basedir, 20)
1929         # to make sure that one of the initial peers is broken, we have to
1930         # get creative. We create the keys, so we can figure out the storage
1931         # index, but we hold off on doing the initial publish until we've
1932         # broken the server on which the first share wants to be stored.
1933         n = FastMutableFileNode(self.client)
1934         d = defer.succeed(None)
1935         d.addCallback(n._generate_pubprivkeys, keysize=522)
1936         d.addCallback(n._generated)
1937         def _break_peer0(res):
1938             si = n.get_storage_index()
1939             peerlist = self.client.storage_broker.get_servers_for_index(si)
1940             peerid0, connection0 = peerlist[0]
1941             peerid1, connection1 = peerlist[1]
1942             connection0.broken = True
1943             self.connection1 = connection1
1944         d.addCallback(_break_peer0)
1945         # now let the initial publish finally happen
1946         d.addCallback(lambda res: n._upload("contents 1", None))
1947         # that ought to work
1948         d.addCallback(lambda res: n.download_best_version())
1949         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1950         # now break the second peer
1951         def _break_peer1(res):
1952             self.connection1.broken = True
1953         d.addCallback(_break_peer1)
1954         d.addCallback(lambda res: n.overwrite("contents 2"))
1955         # that ought to work too
1956         d.addCallback(lambda res: n.download_best_version())
1957         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1958         def _explain_error(f):
1959             print f
1960             if f.check(NotEnoughServersError):
1961                 print "first_error:", f.value.first_error
1962             return f
1963         d.addErrback(_explain_error)
1964         return d
1965
1966     def test_bad_server_overlap(self):
1967         # like test_bad_server, but with no extra unused servers to fall back
1968         # upon. This means that we must re-use a server which we've already
1969         # used. If we don't remember the fact that we sent them one share
1970         # already, we'll mistakenly think we're experiencing an
1971         # UncoordinatedWriteError.
1972
1973         # Break one server, then create the file: the initial publish should
1974         # complete with an alternate server. Breaking a second server should
1975         # not prevent an update from succeeding either.
1976         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1977         self.client = LessFakeClient(basedir, 10)
1978         sb = self.client.get_storage_broker()
1979
1980         peerids = list(sb.get_all_serverids())
1981         self.client.debug_break_connection(peerids[0])
1982
1983         d = self.client.create_mutable_file("contents 1")
1984         def _created(n):
1985             d = n.download_best_version()
1986             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1987             # now break one of the remaining servers
1988             def _break_second_server(res):
1989                 self.client.debug_break_connection(peerids[1])
1990             d.addCallback(_break_second_server)
1991             d.addCallback(lambda res: n.overwrite("contents 2"))
1992             # that ought to work too
1993             d.addCallback(lambda res: n.download_best_version())
1994             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1995             return d
1996         d.addCallback(_created)
1997         return d
1998
1999     def test_publish_all_servers_bad(self):
2000         # Break all servers: the publish should fail
2001         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
2002         self.client = LessFakeClient(basedir, 20)
2003         sb = self.client.get_storage_broker()
2004         for peerid in sb.get_all_serverids():
2005             self.client.debug_break_connection(peerid)
2006         d = self.shouldFail(NotEnoughServersError,
2007                             "test_publish_all_servers_bad",
2008                             "Ran out of non-bad servers",
2009                             self.client.create_mutable_file, "contents")
2010         return d
2011
2012     def test_publish_no_servers(self):
2013         # no servers at all: the publish should fail
2014         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
2015         self.client = LessFakeClient(basedir, 0)
2016         d = self.shouldFail(NotEnoughServersError,
2017                             "test_publish_no_servers",
2018                             "Ran out of non-bad servers",
2019                             self.client.create_mutable_file, "contents")
2020         return d
2021     test_publish_no_servers.timeout = 30
2022
2023
2024     def test_privkey_query_error(self):
2025         # when a servermap is updated with MODE_WRITE, it tries to get the
2026         # privkey. Something might go wrong during this query attempt.
2027         self.client = FakeClient(20)
2028         # we need some contents that are large enough to push the privkey out
2029         # of the early part of the file
2030         LARGE = "These are Larger contents" * 200 # about 5KB
2031         d = self.client.create_mutable_file(LARGE)
2032         def _created(n):
2033             self.uri = n.get_uri()
2034             self.n2 = self.client.create_node_from_uri(self.uri)
2035             # we start by doing a map update to figure out which is the first
2036             # server.
2037             return n.get_servermap(MODE_WRITE)
2038         d.addCallback(_created)
2039         d.addCallback(lambda res: fireEventually(res))
2040         def _got_smap1(smap):
2041             peer0 = list(smap.make_sharemap()[0])[0]
2042             # we tell the server to respond to this peer first, so that it
2043             # will be asked for the privkey first
2044             self.client._storage._sequence = [peer0]
2045             # now we make the peer fail their second query
2046             self.client._storage._special_answers[peer0] = ["normal", "fail"]
2047         d.addCallback(_got_smap1)
2048         # now we update a servermap from a new node (which doesn't have the
2049         # privkey yet, forcing it to use a separate privkey query). Each
2050         # query response will trigger a privkey query, and since we're using
2051         # _sequence to make the peer0 response come back first, we'll send it
2052         # a privkey query first, and _sequence will again ensure that the
2053         # peer0 query will also come back before the others, and then
2054         # _special_answers will make sure that the query raises an exception.
2055         # The whole point of these hijinks is to exercise the code in
2056         # _privkey_query_failed. Note that the map-update will succeed, since
2057         # we'll just get a copy from one of the other shares.
2058         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2059         # Using FakeStorage._sequence means there will be read requests still
2060         # floating around.. wait for them to retire
2061         def _cancel_timer(res):
2062             if self.client._storage._pending_timer:
2063                 self.client._storage._pending_timer.cancel()
2064             return res
2065         d.addBoth(_cancel_timer)
2066         return d
2067
2068     def test_privkey_query_missing(self):
2069         # like test_privkey_query_error, but the shares are deleted by the
2070         # second query, instead of raising an exception.
2071         self.client = FakeClient(20)
2072         LARGE = "These are Larger contents" * 200 # about 5KB
2073         d = self.client.create_mutable_file(LARGE)
2074         def _created(n):
2075             self.uri = n.get_uri()
2076             self.n2 = self.client.create_node_from_uri(self.uri)
2077             return n.get_servermap(MODE_WRITE)
2078         d.addCallback(_created)
2079         d.addCallback(lambda res: fireEventually(res))
2080         def _got_smap1(smap):
2081             peer0 = list(smap.make_sharemap()[0])[0]
2082             self.client._storage._sequence = [peer0]
2083             self.client._storage._special_answers[peer0] = ["normal", "none"]
2084         d.addCallback(_got_smap1)
2085         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2086         def _cancel_timer(res):
2087             if self.client._storage._pending_timer:
2088                 self.client._storage._pending_timer.cancel()
2089             return res
2090         d.addBoth(_cancel_timer)
2091         return d