
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

class IntentionalError(Exception):
    # FakeStorage's "fail" answer mode raises this; defined here so the
    # module is self-contained (read() below references the name)
    pass

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

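    # Illustrative usage sketch: a test can reorder or sabotage responses
    # before starting a retrieve, e.g.
    #
    #   s = client._storage
    #   s._sequence = list(s._peers.keys())    # answer reads in this order
    #   s._special_answers[peerid] = ["fail"]  # first read from peerid errors
    #
    # "fail" answers with an IntentionalError failure, "none" returns no
    # shares, and "normal" answers normally; sequenced reads are all released
    # one second after the first query arrives.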

class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
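        # readv is a list of (offset, length) pairs; the answer maps each
        # requested shnum to a list of byte strings, one per read range,
        # mirroring the remote slot_readv interface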
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)
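
    # Note that this fake always answers (True, readv): every test vector is
    # reported as passing, so writes are never vetoed here. Tests that need
    # conflicting or stale shares arrange them by editing FakeStorage._peers
    # directly (see corrupt() and PublishMixin._set_versions below).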


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
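        # This mirrors Tahoe's permuted-ring peer selection: each peer is
        # ranked by SHA-1(key + peerid), so different storage indexes see
        # the peers in different orders.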
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
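
# Illustrative: corrupt(None, storage, ("share_hash_chain", 4)) flips one bit
# four bytes into every share's hash chain, while shnums_to_corrupt=[0] would
# limit the damage to share 0 on each server. String offsets name fields in
# the share's offset table ("pubkey" is special-cased at byte 107); integer
# offsets index into the raw share.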

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
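        # verinfo is (seqnum, root_hash, IV, segsize, datalen, k, N, prefix,
        # offsets), so verinfo[0] below is the version's sequence number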
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
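                # 21 bytes of plaintext, k=3, one segment: each of the ten
                # FEC shares comes out to 21/3 = 7 bytes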
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10))
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
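        # e.g. _set_versions(dict([(i, 2) for i in range(10)])) rolls every
        # share back to the snapshot captured after version 2 was published
        # (see the rollback step in publish_multiple above).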
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm
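    # The shares_available() tuple checked above is (shares-found, k, N).
    # With k=3, N=10, and 20 peers: MODE_CHECK and MODE_WRITE query enough
    # servers to see all 10 shares, MODE_READ stops at k+epsilon shares
    # (epsilon=k, so 6), and MODE_ANYTHING stops at k=3, just enough to
    # prove one recoverable version.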
    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
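        # corrupt_early=True damages the shares before the servermap update,
        # so mapupdate itself notices the problem; corrupt_early=False
        # damages them between mapupdate and retrieve, so the failure
        # surfaces during the download instead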
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

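    # The integer offsets used below index into the SDMF share header, which
    # is packed like sig_material in MakeShares above (">BQ32s16s BBQQ"):
    # byte 0 is the version byte, bytes 1-8 the seqnum, 9-40 the root hash R,
    # 41-56 the salt/IV, 57 k, 58 N, 59-66 the segment size, and 67-74 the
    # data length.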
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

1193     def test_corrupt_some(self):
1194         # corrupt the data of the first five shares (so the servermap thinks
1195         # they're good but retrieve marks them as bad), so that the
1196         # MODE_READ set of 6 will be insufficient, forcing
1197         # download_best_version to retry with more servers.
1198         corrupt(None, self._storage, "share_data", range(5))
1199         d = self.make_servermap()
1200         def _do_retrieve(servermap):
1201             ver = servermap.best_recoverable_version()
1202             self.failUnless(ver)
1203             return self._fn.download_best_version()
1204         d.addCallback(_do_retrieve)
1205         d.addCallback(lambda new_contents:
1206                       self.failUnlessEqual(new_contents, self.CONTENTS))
1207         return d
1208
1209     def test_download_fails(self):
1210         corrupt(None, self._storage, "signature")
1211         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1212                             "no recoverable versions",
1213                             self._fn.download_best_version)
1214         return d
1215
1216
1217 class CheckerMixin:
1218     def check_good(self, r, where):
1219         self.failUnless(r.is_healthy(), where)
1220         return r
1221
1222     def check_bad(self, r, where):
1223         self.failIf(r.is_healthy(), where)
1224         return r
1225
1226     def check_expected_failure(self, r, expected_exception, substring, where):
1227         for (peerid, storage_index, shnum, f) in r.problems:
1228             if f.check(expected_exception):
1229                 self.failUnless(substring in str(f),
1230                                 "%s: substring '%s' not in '%s'" %
1231                                 (where, substring, str(f)))
1232                 return
1233         self.fail("%s: didn't see expected exception %s in problems %s" %
1234                   (where, expected_exception, r.problems))
1235
1236
1237 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1238     def setUp(self):
1239         return self.publish_one()
1240
1241
1242     def test_check_good(self):
1243         d = self._fn.check(Monitor())
1244         d.addCallback(self.check_good, "test_check_good")
1245         return d
1246
1247     def test_check_no_shares(self):
1248         for shares in self._storage._peers.values():
1249             shares.clear()
1250         d = self._fn.check(Monitor())
1251         d.addCallback(self.check_bad, "test_check_no_shares")
1252         return d
1253
1254     def test_check_not_enough_shares(self):
1255         for shares in self._storage._peers.values():
1256             for shnum in shares.keys():
1257                 if shnum > 0:
1258                     del shares[shnum]
1259         d = self._fn.check(Monitor())
1260         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1261         return d
1262
1263     def test_check_all_bad_sig(self):
1264         corrupt(None, self._storage, 1) # bad sig
1265         d = self._fn.check(Monitor())
1266         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1267         return d
1268
1269     def test_check_all_bad_blocks(self):
1270         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1271         # the Checker won't notice this, since it doesn't look at actual data
1272         d = self._fn.check(Monitor())
1273         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1274         return d
1275
1276     def test_verify_good(self):
1277         d = self._fn.check(Monitor(), verify=True)
1278         d.addCallback(self.check_good, "test_verify_good")
1279         return d
1280
1281     def test_verify_all_bad_sig(self):
1282         corrupt(None, self._storage, 1) # bad sig
1283         d = self._fn.check(Monitor(), verify=True)
1284         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1285         return d
1286
1287     def test_verify_one_bad_sig(self):
1288         corrupt(None, self._storage, 1, [9]) # bad sig
1289         d = self._fn.check(Monitor(), verify=True)
1290         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1291         return d
1292
1293     def test_verify_one_bad_block(self):
1294         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1295         # the Verifier *will* notice this, since it examines every byte
1296         d = self._fn.check(Monitor(), verify=True)
1297         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1298         d.addCallback(self.check_expected_failure,
1299                       CorruptShareError, "block hash tree failure",
1300                       "test_verify_one_bad_block")
1301         return d
1302
1303     def test_verify_one_bad_sharehash(self):
1304         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1305         d = self._fn.check(Monitor(), verify=True)
1306         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1307         d.addCallback(self.check_expected_failure,
1308                       CorruptShareError, "corrupt hashes",
1309                       "test_verify_one_bad_sharehash")
1310         return d
1311
1312     def test_verify_one_bad_encprivkey(self):
1313         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1314         d = self._fn.check(Monitor(), verify=True)
1315         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1316         d.addCallback(self.check_expected_failure,
1317                       CorruptShareError, "invalid privkey",
1318                       "test_verify_one_bad_encprivkey")
1319         return d
1320
1321     def test_verify_one_bad_encprivkey_uncheckable(self):
1322         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1323         readonly_fn = self._fn.get_readonly()
1324         # a read-only node has no way to validate the privkey
1325         d = readonly_fn.check(Monitor(), verify=True)
1326         d.addCallback(self.check_good,
1327                       "test_verify_one_bad_encprivkey_uncheckable")
1328         return d
1329
1330 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1331
1332     def get_shares(self, s):
1333         all_shares = {} # maps (peerid, shnum) to share data
1334         for peerid in s._peers:
1335             shares = s._peers[peerid]
1336             for shnum in shares:
1337                 data = shares[shnum]
1338                 all_shares[ (peerid, shnum) ] = data
1339         return all_shares
1340
1341     def copy_shares(self, ignored=None):
1342         self.old_shares.append(self.get_shares(self._storage))
1343
1344     def test_repair_nop(self):
1345         self.old_shares = []
1346         d = self.publish_one()
1347         d.addCallback(self.copy_shares)
1348         d.addCallback(lambda res: self._fn.check(Monitor()))
1349         d.addCallback(lambda check_results: self._fn.repair(check_results))
1350         def _check_results(rres):
1351             self.failUnless(IRepairResults.providedBy(rres))
1352             # TODO: examine results
1353
1354             self.copy_shares()
1355
1356             initial_shares = self.old_shares[0]
1357             new_shares = self.old_shares[1]
1358             # TODO: this really shouldn't change anything. When we implement
1359             # a "minimal-bandwidth" repairer, change this test to assert:
1360             #self.failUnlessEqual(new_shares, initial_shares)
1361
1362             # all shares should be in the same place as before
1363             self.failUnlessEqual(set(initial_shares.keys()),
1364                                  set(new_shares.keys()))
1365             # but they should all be at a newer seqnum. The IV will be
1366             # different, so the roothash will be too.
1367             for key in initial_shares:
1368                 (version0,
1369                  seqnum0,
1370                  root_hash0,
1371                  IV0,
1372                  k0, N0, segsize0, datalen0,
1373                  o0) = unpack_header(initial_shares[key])
1374                 (version1,
1375                  seqnum1,
1376                  root_hash1,
1377                  IV1,
1378                  k1, N1, segsize1, datalen1,
1379                  o1) = unpack_header(new_shares[key])
1380                 self.failUnlessEqual(version0, version1)
1381                 self.failUnlessEqual(seqnum0+1, seqnum1)
1382                 self.failUnlessEqual(k0, k1)
1383                 self.failUnlessEqual(N0, N1)
1384                 self.failUnlessEqual(segsize0, segsize1)
1385                 self.failUnlessEqual(datalen0, datalen1)
1386         d.addCallback(_check_results)
1387         return d
1388
1389     def failIfSharesChanged(self, ignored=None):
1390         old_shares = self.old_shares[-2]
1391         current_shares = self.old_shares[-1]
1392         self.failUnlessEqual(old_shares, current_shares)
1393
1394     def test_merge(self):
1395         self.old_shares = []
1396         d = self.publish_multiple()
1397         # repair will refuse to merge multiple recoverable versions that
1398         # share the highest seqnum unless you pass force=True
1399         d.addCallback(lambda res:
1400                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1401                                           1:4,3:4,5:4,7:4,9:4}))
1402         d.addCallback(self.copy_shares)
1403         d.addCallback(lambda res: self._fn.check(Monitor()))
1404         def _try_repair(check_results):
1405             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1406             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1407                                  self._fn.repair, check_results)
1408             d2.addCallback(self.copy_shares)
1409             d2.addCallback(self.failIfSharesChanged)
1410             d2.addCallback(lambda res: check_results)
1411             return d2
1412         d.addCallback(_try_repair)
1413         d.addCallback(lambda check_results:
1414                       self._fn.repair(check_results, force=True))
1415         # this should give us 10 shares of the highest roothash
1416         def _check_repair_results(rres):
1417             pass # TODO
1418         d.addCallback(_check_repair_results)
1419         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1420         def _check_smap(smap):
1421             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1422             self.failIf(smap.unrecoverable_versions())
1423             # now, which should have won?
1424             roothash_s4a = self.get_roothash_for(3)
1425             roothash_s4b = self.get_roothash_for(4)
1426             if roothash_s4b > roothash_s4a:
1427                 expected_contents = self.CONTENTS[4]
1428             else:
1429                 expected_contents = self.CONTENTS[3]
1430             new_versionid = smap.best_recoverable_version()
1431             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1432             d2 = self._fn.download_version(smap, new_versionid)
1433             d2.addCallback(self.failUnlessEqual, expected_contents)
1434             return d2
1435         d.addCallback(_check_smap)
1436         return d
1437
1438     def test_non_merge(self):
1439         self.old_shares = []
1440         d = self.publish_multiple()
1441         # repair should not refuse a repair that doesn't need to merge. In
1442         # this case, we combine v2 with v3. The repair should ignore v2 and
1443         # copy v3 into a new v5.
1444         d.addCallback(lambda res:
1445                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1446                                           1:3,3:3,5:3,7:3,9:3}))
1447         d.addCallback(lambda res: self._fn.check(Monitor()))
1448         d.addCallback(lambda check_results: self._fn.repair(check_results))
1449         # this should give us 10 shares of v3
1450         def _check_repair_results(rres):
1451             pass # TODO
1452         d.addCallback(_check_repair_results)
1453         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1454         def _check_smap(smap):
1455             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1456             self.failIf(smap.unrecoverable_versions())
1457             # v3 is the only recoverable version, so it must win; no
1458             # roothash comparison is needed here (unlike test_merge)
1459             expected_contents = self.CONTENTS[3]
1460             new_versionid = smap.best_recoverable_version()
1461             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1462             d2 = self._fn.download_version(smap, new_versionid)
1463             d2.addCallback(self.failUnlessEqual, expected_contents)
1464             return d2
1465         d.addCallback(_check_smap)
1466         return d
1467
1468     def get_roothash_for(self, index):
1469         # return the roothash for the first share we see in the saved set
1470         shares = self._copied_shares[index]
1471         for peerid in shares:
1472             for shnum in shares[peerid]:
1473                 share = shares[peerid][shnum]
1474                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1475                           unpack_header(share)
1476                 return root_hash
1477
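# For reference, a minimal sketch of the SDMF header fields that
# unpack_header() returns in the Repair tests above. The struct format here
# is an assumption based on allmydata.mutable.layout (that module is the
# authoritative definition); this helper exists only to document field order.
def _example_unpack_sdmf_header(data):
    # 1-byte version, 8-byte seqnum, 32-byte root hash, 16-byte IV,
    # 1-byte k, 1-byte N, 8-byte segment size, 8-byte data length,
    # then six offsets into the remainder of the share
    HEADER = ">BQ32s16s BBQQ LLLLQQ"
    fields = struct.unpack(HEADER, data[:struct.calcsize(HEADER)])
    (version, seqnum, root_hash, IV, k, N, segsize, datalen) = fields[:8]
    return (version, seqnum, root_hash, IV, k, N, segsize, datalen, fields[8:])
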
1478 class MultipleEncodings(unittest.TestCase):
1479     def setUp(self):
1480         self.CONTENTS = "New contents go here"
1481         num_peers = 20
1482         self._client = FakeClient(num_peers)
1483         self._storage = self._client._storage
1484         d = self._client.create_mutable_file(self.CONTENTS)
1485         def _created(node):
1486             self._fn = node
1487         d.addCallback(_created)
1488         return d
1489
1490     def _encode(self, k, n, data):
1491         # encode 'data' with k-of-n parameters into a peerid->{shnum: share} dict.
1492
1493         fn2 = FastMutableFileNode(self._client)
1494         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1495         # and _fingerprint
1496         fn = self._fn
1497         fn2.init_from_uri(fn.get_uri())
1498         # then we copy over other fields that are normally fetched from the
1499         # existing shares
1500         fn2._pubkey = fn._pubkey
1501         fn2._privkey = fn._privkey
1502         fn2._encprivkey = fn._encprivkey
1503         # and set the encoding parameters to something completely different
1504         fn2._required_shares = k
1505         fn2._total_shares = n
1506
1507         s = self._client._storage
1508         s._peers = {} # clear existing storage
1509         p2 = Publish(fn2, None)
1510         d = p2.publish(data)
1511         def _published(res):
1512             shares = s._peers
1513             s._peers = {}
1514             return shares
1515         d.addCallback(_published)
1516         return d
1517
1518     def make_servermap(self, mode=MODE_READ, oldmap=None):
1519         if oldmap is None:
1520             oldmap = ServerMap()
1521         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1522         d = smu.update()
1523         return d
1524
1525     def test_multiple_encodings(self):
1526         # we encode the same file in three different ways (3-of-10, 4-of-9,
1527         # and 4-of-7), then mix up the shares, to make sure that download
1528         # survives seeing a variety of encodings. This is tricky to set up.
1529
1530         contents1 = "Contents for encoding 1 (3-of-10) go here"
1531         contents2 = "Contents for encoding 2 (4-of-9) go here"
1532         contents3 = "Contents for encoding 3 (4-of-7) go here"
1533
1534         # we make a file node that doesn't know what encoding
1535         # parameters to use
1536         fn3 = FastMutableFileNode(self._client)
1537         fn3.init_from_uri(self._fn.get_uri())
1538
1539         # now we publish the file under each encoding in turn, grabbing the shares
1540         d = self._encode(3, 10, contents1)
1541         def _encoded_1(shares):
1542             self._shares1 = shares
1543         d.addCallback(_encoded_1)
1544         d.addCallback(lambda res: self._encode(4, 9, contents2))
1545         def _encoded_2(shares):
1546             self._shares2 = shares
1547         d.addCallback(_encoded_2)
1548         d.addCallback(lambda res: self._encode(4, 7, contents3))
1549         def _encoded_3(shares):
1550             self._shares3 = shares
1551         d.addCallback(_encoded_3)
1552
1553         def _merge(res):
1554             log.msg("merging sharelists")
1555             # we merge the shares from the three sets, leaving each shnum in
1556             # its original location, but using a share from set1, set2, or
1557             # set3 according to the following sequence:
1558             #
1559             #  4-of-9  a  s2
1560             #  4-of-9  b  s2
1561             #  4-of-7  c   s3
1562             #  4-of-9  d  s2
1563             #  3-of-10 e s1
1564             #  3-of-10 f s1
1565             #  3-of-10 g s1
1566             #  4-of-9  h  s2
1567             #
1568             # so that no version can be recovered until fetch [g], at which
1569             # point version-s1 (the 3-of-10 form, k=3) becomes recoverable.
1570             # If the implementation latches on to the first version it sees,
1571             # then s2 (k=4) will not be recoverable until fetch [h].
1572
1573             # Later, when we implement code that handles multiple versions,
1574             # we can use this framework to assert that all recoverable
1575             # versions are retrieved, and test that 'epsilon' does its job
1576
1577             places = [2, 2, 3, 2, 1, 1, 1, 2]
1578
1579             sharemap = {}
1580
1581             for i,peerid in enumerate(self._client._peerids):
1582                 peerid_s = shortnodeid_b2a(peerid)
1583                 for shnum in self._shares1.get(peerid, {}):
1584                     if shnum < len(places):
1585                         which = places[shnum]
1586                     else:
1587                         which = "x"
1588                     self._client._storage._peers[peerid] = peers = {}
1589                     in_1 = shnum in self._shares1[peerid]
1590                     in_2 = shnum in self._shares2.get(peerid, {})
1591                     in_3 = shnum in self._shares3.get(peerid, {})
1592                     #print peerid_s, shnum, which, in_1, in_2, in_3
1593                     if which == 1:
1594                         if in_1:
1595                             peers[shnum] = self._shares1[peerid][shnum]
1596                             sharemap[shnum] = peerid
1597                     elif which == 2:
1598                         if in_2:
1599                             peers[shnum] = self._shares2[peerid][shnum]
1600                             sharemap[shnum] = peerid
1601                     elif which == 3:
1602                         if in_3:
1603                             peers[shnum] = self._shares3[peerid][shnum]
1604                             sharemap[shnum] = peerid
1605
1606             # we don't bother placing any other shares
1607             # now sort the sequence so that share 0 is returned first
1608             new_sequence = [sharemap[shnum]
1609                             for shnum in sorted(sharemap.keys())]
1610             self._client._storage._sequence = new_sequence
1611             log.msg("merge done")
1612         d.addCallback(_merge)
1613         d.addCallback(lambda res: fn3.download_best_version())
1614         def _retrieved(new_contents):
1615             # the current specified behavior is "first version recoverable"
1616             self.failUnlessEqual(new_contents, contents1)
1617         d.addCallback(_retrieved)
1618         return d
1619
1620
1621 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1622
1623     def setUp(self):
1624         return self.publish_multiple()
1625
1626     def test_multiple_versions(self):
1627         # if we see a mix of versions in the grid, download_best_version
1628         # should get the latest one
1629         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1630         d = self._fn.download_best_version()
1631         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1632         # and the checker should report problems
1633         d.addCallback(lambda res: self._fn.check(Monitor()))
1634         d.addCallback(self.check_bad, "test_multiple_versions")
1635
1636         # but if everything is at version 2, that's what we should download
1637         d.addCallback(lambda res:
1638                       self._set_versions(dict([(i,2) for i in range(10)])))
1639         d.addCallback(lambda res: self._fn.download_best_version())
1640         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1641         # if exactly one share is at version 3, we should still get v2
1642         d.addCallback(lambda res:
1643                       self._set_versions({0:3}))
1644         d.addCallback(lambda res: self._fn.download_best_version())
1645         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1646         # but the servermap should see the unrecoverable version. This
1647         # depends upon the single newer share being queried early.
1648         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1649         def _check_smap(smap):
1650             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1651             newer = smap.unrecoverable_newer_versions()
1652             self.failUnlessEqual(len(newer), 1)
1653             verinfo, health = newer.items()[0]
1654             self.failUnlessEqual(verinfo[0], 4)
1655             self.failUnlessEqual(health, (1,3))
1656             self.failIf(smap.needs_merge())
1657         d.addCallback(_check_smap)
1658         # if we have a mix of two parallel versions (s4a and s4b), we could
1659         # recover either
1660         d.addCallback(lambda res:
1661                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1662                                           1:4,3:4,5:4,7:4,9:4}))
1663         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1664         def _check_smap_mixed(smap):
1665             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1666             newer = smap.unrecoverable_newer_versions()
1667             self.failUnlessEqual(len(newer), 0)
1668             self.failUnless(smap.needs_merge())
1669         d.addCallback(_check_smap_mixed)
1670         d.addCallback(lambda res: self._fn.download_best_version())
1671         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1672                                                   res == self.CONTENTS[4]))
1673         return d
1674
1675     def test_replace(self):
1676         # if we see a mix of versions in the grid, we should be able to
1677         # replace them all with a newer version
1678
1679         # if exactly one share is at version 3, we should download (and
1680         # replace) v2, and the result should be v4. Note that the index we
1681         # give to _set_versions is different from the sequence number.
1682         target = dict([(i,2) for i in range(10)]) # seqnum3
1683         target[0] = 3 # seqnum4
1684         self._set_versions(target)
1685
1686         def _modify(oldversion, servermap, first_time):
1687             return oldversion + " modified"
1688         d = self._fn.modify(_modify)
1689         d.addCallback(lambda res: self._fn.download_best_version())
1690         expected = self.CONTENTS[2] + " modified"
1691         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1692         # and the servermap should indicate that the outlier was replaced too
1693         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1694         def _check_smap(smap):
1695             self.failUnlessEqual(smap.highest_seqnum(), 5)
1696             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1697             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1698         d.addCallback(_check_smap)
1699         return d
1700
1701
1702 class Utils(unittest.TestCase):
1703     def test_dict_of_sets(self):
1704         ds = DictOfSets()
1705         ds.add(1, "a")
1706         ds.add(2, "b")
1707         ds.add(2, "b")
1708         ds.add(2, "c")
1709         self.failUnlessEqual(ds[1], set(["a"]))
1710         self.failUnlessEqual(ds[2], set(["b", "c"]))
1711         ds.discard(3, "d") # should not raise an exception
1712         ds.discard(2, "b")
1713         self.failUnlessEqual(ds[2], set(["c"]))
1714         ds.discard(2, "c")
1715         self.failIf(2 in ds)
1716
1717     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1718         # we compare this against sets of integers
1719         x = set(range(x_start, x_start+x_length))
1720         y = set(range(y_start, y_start+y_length))
1721         should_be_inside = x.issubset(y)
1722         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1723                                                          y_start, y_length),
1724                              str((x_start, x_length, y_start, y_length)))
1725
1726     def test_cache_inside(self):
1727         c = ResponseCache()
1728         x_start = 10
1729         x_length = 5
1730         for y_start in range(8, 17):
1731             for y_length in range(8):
1732                 self._do_inside(c, x_start, x_length, y_start, y_length)
1733
1734     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1735         # we compare this against sets of integers
1736         x = set(range(x_start, x_start+x_length))
1737         y = set(range(y_start, y_start+y_length))
1738         overlap = bool(x.intersection(y))
1739         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1740                                                       y_start, y_length),
1741                              str((x_start, x_length, y_start, y_length)))
1742
1743     def test_cache_overlap(self):
1744         c = ResponseCache()
1745         x_start = 10
1746         x_length = 5
1747         for y_start in range(8, 17):
1748             for y_length in range(8):
1749                 self._do_overlap(c, x_start, x_length, y_start, y_length)
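
    # A sketch of the interval arithmetic that the two helpers above compare
    # against set arithmetic, written out directly. This direct form assumes
    # non-empty ranges (an empty range is a vacuous subset, which it does not
    # reproduce); ResponseCache's own _inside/_does_overlap are authoritative.
    def _example_interval_checks(self, x_start, x_length, y_start, y_length):
        # x lies inside y when it starts no earlier and ends no later than y
        inside = (x_start >= y_start and
                  x_start + x_length <= y_start + y_length)
        # x and y overlap when each one starts before the other ends
        overlap = (x_start < y_start + y_length and
                   y_start < x_start + x_length)
        return (inside, overlap)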
1750
1751     def test_cache(self):
1752         c = ResponseCache()
1753         # xdata = base62.b2a(os.urandom(100))[:100]
1754         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1755         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1756         nope = (None, None)
1757         c.add("v1", 1, 0, xdata, "time0")
1758         c.add("v1", 1, 2000, ydata, "time1")
1759         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1760         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1761         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1762         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1763         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1764         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1765         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1766         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1767         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1768         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1769         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1770         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1771         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1772         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1773         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1774         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1775         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1776         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1777
1778         # optional: join fragments (not implemented yet, so the assertion below stays commented out)
1779         c = ResponseCache()
1780         c.add("v1", 1, 0, xdata[:10], "time0")
1781         c.add("v1", 1, 10, xdata[10:20], "time1")
1782         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1783
1784 class Exceptions(unittest.TestCase):
1785     def test_repr(self):
1786         nmde = NeedMoreDataError(100, 50, 100)
1787         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1788         ucwe = UncoordinatedWriteError()
1789         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1790
1791 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1792 # instances which always succeed. So we need a less-fake one.
1793
1794 class IntentionalError(Exception):
1795     pass
1796
1797 class LocalWrapper:
1798     def __init__(self, original):
1799         self.original = original
1800         self.broken = False
1801         self.post_call_notifier = None
1802     def callRemote(self, methname, *args, **kwargs):
1803         def _call():
1804             if self.broken:
1805                 raise IntentionalError("I was asked to break")
1806             meth = getattr(self.original, "remote_" + methname)
1807             return meth(*args, **kwargs)
1808         d = fireEventually()
1809         d.addCallback(lambda res: _call())
1810         if self.post_call_notifier:
1811             d.addCallback(self.post_call_notifier, methname)
1812         return d
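
# A minimal usage sketch for LocalWrapper (illustrative only, not exercised
# by the suite): tests flip .broken to simulate a failing server, and set
# .post_call_notifier to observe which remote methods get invoked.
def _example_break_and_watch(wrapper, calls_seen):
    def _notify(res, methname):
        calls_seen.append(methname) # record each remote method name
        return res
    wrapper.post_call_notifier = _notify
    wrapper.broken = True # subsequent callRemote()s will errback
    return wrapper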
1813
1814 class LessFakeClient(FakeClient):
1815
1816     def __init__(self, basedir, num_peers=10):
1817         self._num_peers = num_peers
1818         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1819                          for i in range(self._num_peers)]
1820         self._connections = {}
1821         for peerid in self._peerids:
1822             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1823             make_dirs(peerdir)
1824             ss = storage.StorageServer(peerdir)
1825             ss.setNodeID(peerid)
1826             lw = LocalWrapper(ss)
1827             self._connections[peerid] = lw
1828         self.nodeid = "fakenodeid"
1829
1830
1831 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1832     def test_publish_surprise(self):
1833         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1834         self.client = LessFakeClient(basedir)
1835         d = self.client.create_mutable_file("contents 1")
1836         def _created(n):
1837             d = defer.succeed(None)
1838             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1839             def _got_smap1(smap):
1840                 # stash the old state of the file
1841                 self.old_map = smap
1842             d.addCallback(_got_smap1)
1843             # then modify the file, leaving the old map untouched
1844             d.addCallback(lambda res: log.msg("starting winning write"))
1845             d.addCallback(lambda res: n.overwrite("contents 2"))
1846             # now attempt to modify the file with the old servermap. This
1847             # will look just like an uncoordinated write, in which every
1848             # single share got updated between our mapupdate and our publish
1849             d.addCallback(lambda res: log.msg("starting doomed write"))
1850             d.addCallback(lambda res:
1851                           self.shouldFail(UncoordinatedWriteError,
1852                                           "test_publish_surprise", None,
1853                                           n.upload,
1854                                           "contents 2a", self.old_map))
1855             return d
1856         d.addCallback(_created)
1857         return d
1858
1859     def test_retrieve_surprise(self):
1860         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1861         self.client = LessFakeClient(basedir)
1862         d = self.client.create_mutable_file("contents 1")
1863         def _created(n):
1864             d = defer.succeed(None)
1865             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1866             def _got_smap1(smap):
1867                 # stash the old state of the file
1868                 self.old_map = smap
1869             d.addCallback(_got_smap1)
1870             # then modify the file, leaving the old map untouched
1871             d.addCallback(lambda res: log.msg("starting winning write"))
1872             d.addCallback(lambda res: n.overwrite("contents 2"))
1873             # now attempt to retrieve the old version with the old servermap.
1874             # This will look like someone has changed the file since we
1875             # updated the servermap.
1876             d.addCallback(lambda res: n._cache._clear())
1877             d.addCallback(lambda res: log.msg("starting doomed read"))
1878             d.addCallback(lambda res:
1879                           self.shouldFail(NotEnoughSharesError,
1880                                           "test_retrieve_surprise",
1881                                           "ran out of peers: have 0 shares (k=3)",
1882                                           n.download_version,
1883                                           self.old_map,
1884                                           self.old_map.best_recoverable_version(),
1885                                           ))
1886             return d
1887         d.addCallback(_created)
1888         return d
1889
1890     def test_unexpected_shares(self):
1891         # upload the file, take a servermap, shut down one of the servers,
1892         # upload it again (causing shares to appear on a new server), then
1893         # upload using the old servermap. The last upload should fail with an
1894         # UncoordinatedWriteError, because of the shares that didn't appear
1895         # in the servermap.
1896         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1897         self.client = LessFakeClient(basedir)
1898         d = self.client.create_mutable_file("contents 1")
1899         def _created(n):
1900             d = defer.succeed(None)
1901             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1902             def _got_smap1(smap):
1903                 # stash the old state of the file
1904                 self.old_map = smap
1905                 # now shut down one of the servers
1906                 peer0 = list(smap.make_sharemap()[0])[0]
1907                 self.client._connections.pop(peer0)
1908                 # then modify the file, leaving the old map untouched
1909                 log.msg("starting winning write")
1910                 return n.overwrite("contents 2")
1911             d.addCallback(_got_smap1)
1912             # now attempt to modify the file with the old servermap. This
1913             # will look just like an uncoordinated write, in which every
1914             # single share got updated between our mapupdate and our publish
1915             d.addCallback(lambda res: log.msg("starting doomed write"))
1916             d.addCallback(lambda res:
1917                           self.shouldFail(UncoordinatedWriteError,
1918                                           "test_unexpected_shares", None,
1919                                           n.upload,
1920                                           "contents 2a", self.old_map))
1921             return d
1922         d.addCallback(_created)
1923         return d
1924
1925     def test_bad_server(self):
1926         # Break one server, then create the file: the initial publish should
1927         # complete with an alternate server. Breaking a second server should
1928         # not prevent an update from succeeding either.
1929         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1930         self.client = LessFakeClient(basedir, 20)
1931         # to make sure that one of the initial peers is broken, we have to
1932         # get creative. We create the keys, so we can figure out the storage
1933         # index, but we hold off on doing the initial publish until we've
1934         # broken the server on which the first share wants to be stored.
1935         n = FastMutableFileNode(self.client)
1936         d = defer.succeed(None)
1937         d.addCallback(n._generate_pubprivkeys)
1938         d.addCallback(n._generated)
1939         def _break_peer0(res):
1940             si = n.get_storage_index()
1941             peerlist = self.client.get_permuted_peers("storage", si)
1942             peerid0, connection0 = peerlist[0]
1943             peerid1, connection1 = peerlist[1]
1944             connection0.broken = True
1945             self.connection1 = connection1
1946         d.addCallback(_break_peer0)
1947         # now let the initial publish finally happen
1948         d.addCallback(lambda res: n._upload("contents 1", None))
1949         # that ought to work
1950         d.addCallback(lambda res: n.download_best_version())
1951         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1952         # now break the second peer
1953         def _break_peer1(res):
1954             self.connection1.broken = True
1955         d.addCallback(_break_peer1)
1956         d.addCallback(lambda res: n.overwrite("contents 2"))
1957         # that ought to work too
1958         d.addCallback(lambda res: n.download_best_version())
1959         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1960         return d
1961
1962     def test_bad_server_overlap(self):
1963         # like test_bad_server, but with no extra unused servers to fall back
1964         # upon. This means that we must re-use a server which we've already
1965         # used. If we don't remember the fact that we sent them one share
1966         # already, we'll mistakenly think we're experiencing an
1967         # UncoordinatedWriteError.
1968
1969         # Break one server, then create the file: the initial publish should
1970         # complete with an alternate server. Breaking a second server should
1971         # not prevent an update from succeeding either.
1972         basedir = os.path.join("mutable/CollidingWrites/test_bad_server_overlap")
1973         self.client = LessFakeClient(basedir, 10)
1974
1975         peerids = sorted(self.client._connections.keys())
1976         self.client._connections[peerids[0]].broken = True
1977
1978         d = self.client.create_mutable_file("contents 1")
1979         def _created(n):
1980             d = n.download_best_version()
1981             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1982             # now break one of the remaining servers
1983             def _break_second_server(res):
1984                 self.client._connections[peerids[1]].broken = True
1985             d.addCallback(_break_second_server)
1986             d.addCallback(lambda res: n.overwrite("contents 2"))
1987             # that ought to work too
1988             d.addCallback(lambda res: n.download_best_version())
1989             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1990             return d
1991         d.addCallback(_created)
1992         return d
1993
1994     def test_publish_all_servers_bad(self):
1995         # Break all servers: the publish should fail
1996         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1997         self.client = LessFakeClient(basedir, 20)
1998         for connection in self.client._connections.values():
1999             connection.broken = True
2000         d = self.shouldFail(NotEnoughServersError,
2001                             "test_publish_all_servers_bad",
2002                             "Ran out of non-bad servers",
2003                             self.client.create_mutable_file, "contents")
2004         return d
2005
2006     def test_publish_no_servers(self):
2007         # no servers at all: the publish should fail
2008         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
2009         self.client = LessFakeClient(basedir, 0)
2010         d = self.shouldFail(NotEnoughServersError,
2011                             "test_publish_no_servers",
2012                             "Ran out of non-bad servers",
2013                             self.client.create_mutable_file, "contents")
2014         return d
2015     test_publish_no_servers.timeout = 30
2016
2017
2018     def test_privkey_query_error(self):
2019         # when a servermap is updated with MODE_WRITE, it tries to get the
2020         # privkey. Something might go wrong during this query attempt.
2021         self.client = FakeClient(20)
2022         # we need some contents that are large enough to push the privkey out
2023         # of the early part of the file
2024         LARGE = "These are Larger contents" * 200 # about 5KB
2025         d = self.client.create_mutable_file(LARGE)
2026         def _created(n):
2027             self.uri = n.get_uri()
2028             self.n2 = self.client.create_node_from_uri(self.uri)
2029             # we start by doing a map update to figure out which is the first
2030             # server.
2031             return n.get_servermap(MODE_WRITE)
2032         d.addCallback(_created)
2033         d.addCallback(lambda res: fireEventually(res))
2034         def _got_smap1(smap):
2035             peer0 = list(smap.make_sharemap()[0])[0]
2036             # we tell the fake storage to answer this peer's queries first,
2037             # so that it will be asked for the privkey first
2038             self.client._storage._sequence = [peer0]
2039             # now we make the peer fail its second query
2040             self.client._storage._special_answers[peer0] = ["normal", "fail"]
2041         d.addCallback(_got_smap1)
2042         # now we update a servermap from a new node (which doesn't have the
2043         # privkey yet, forcing it to use a separate privkey query). Each
2044         # query response will trigger a privkey query, and since we're using
2045         # _sequence to make the peer0 response come back first, we'll send it
2046         # a privkey query first, and _sequence will again ensure that the
2047         # peer0 query will also come back before the others, and then
2048         # _special_answers will make sure that the query raises an exception.
2049         # The whole point of these hijinks is to exercise the code in
2050         # _privkey_query_failed. Note that the map-update will succeed, since
2051         # we'll just get a copy from one of the other shares.
2052         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2053         # Using FakeStorage._sequence means there will be read requests still
2054         # floating around; wait for them to retire
2055         def _cancel_timer(res):
2056             if self.client._storage._pending_timer:
2057                 self.client._storage._pending_timer.cancel()
2058             return res
2059         d.addBoth(_cancel_timer)
2060         return d
2061
2062     def test_privkey_query_missing(self):
2063         # like test_privkey_query_error, but the second query reports that
2064         # the shares are gone, instead of raising an exception.
2065         self.client = FakeClient(20)
2066         LARGE = "These are Larger contents" * 200 # about 5KB
2067         d = self.client.create_mutable_file(LARGE)
2068         def _created(n):
2069             self.uri = n.get_uri()
2070             self.n2 = self.client.create_node_from_uri(self.uri)
2071             return n.get_servermap(MODE_WRITE)
2072         d.addCallback(_created)
2073         d.addCallback(lambda res: fireEventually(res))
2074         def _got_smap1(smap):
2075             peer0 = list(smap.make_sharemap()[0])[0]
2076             self.client._storage._sequence = [peer0]
2077             self.client._storage._special_answers[peer0] = ["normal", "none"]
2078         d.addCallback(_got_smap1)
2079         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2080         def _cancel_timer(res):
2081             if self.client._storage._pending_timer:
2082                 self.client._storage._pending_timer.cancel()
2083             return res
2084         d.addBoth(_cancel_timer)
2085         return d