
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri
from allmydata.storage.server import StorageServer
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

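# FakeStorage.read wraps an IntentionalError in a Failure for its "fail"
# special-answer mode. The class is not defined elsewhere in this excerpt,
# so it is defined here (a minimal assumption) to keep the module
# self-contained.
class IntentionalError(Exception):
    pass
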
class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d
    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


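# A minimal usage sketch (not part of the original suite) of the _sequence
# knob described in FakeStorage.__init__: queue two reads, then have them
# answered in a chosen order once the 1.0-second timer fires. The peerid
# strings here are made up.
def _demo_fake_storage_ordering():
    s = FakeStorage()
    s.write("peer-A", "si", 0, 0, "share data for A")
    s.write("peer-B", "si", 0, 0, "share data for B")
    s._sequence = ["peer-B", "peer-A"] # B's answer is released first
    d_a = s.read("peer-A", "si") # deferred, not answered right away
    d_b = s.read("peer-B", "si")
    return d_a, d_b
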
class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)

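# An illustrative sketch (not part of the original suite): driving the fake
# server through its callRemote shim. An empty shnums list means "all
# shares"; the storage_index argument is ignored by FakeStorage.
def _demo_slot_readv(server):
    # read the first 100 bytes of every share this server holds
    return server.callRemote("slot_readv", "si", [], [(0, 100)])
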

# our "FakeClient" implements just enough of the real Client to let the
# tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d

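# An illustrative sketch (not part of the original suite): the peer ordering
# returned by get_permuted_peers is a deterministic function of the key, so
# repeated calls with the same key agree.
def _demo_permuted_peers():
    c = FakeClient(num_peers=5)
    p1 = c.get_permuted_peers("storage", "some key")
    p2 = c.get_permuted_peers("storage", "some key")
    assert p1 == p2 # same key -> same SHA1-permuted ordering
    return p1
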

def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

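# for example, flipping the low bit of byte 1:
#   flip_bit("\x00\x01", 1) == "\x00\x00"
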
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

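# An illustrative call (not exercised below): flip a bit in byte 4 of the
# share_hash_chain field, in shares 0 and 1 only, of a FakeStorage `s`:
#   corrupt(None, s, ("share_hash_chain", 4), shnums_to_corrupt=[0, 1])
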
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
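        # verinfo is the 9-tuple (seqnum, root_hash, IV, segsize, datalength,
        # k, N, prefix, offsets_tuple), so verinfo[0] is the sequence number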
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
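                # 21 bytes of contents erasure-coded with k=3 gives
                # 21/3 == 7-byte shares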
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
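                # ">BQ32s16s BBQQ" is the 75-byte signed SDMF prefix:
                # version byte, seqnum, root_hash, IV/salt, k, N, segsize,
                # datalen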
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer
    # on 10 of them
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version. (see the sketch just below this method)
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
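
    # An illustrative call (not part of the original tests): after
    # publish_multiple(), roll shares 0-4 back to version 2 while leaving
    # shares 5-9 at their current version:
    #   self._set_versions(dict([(i, 2) for i in range(5)]))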


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

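    # Each test below corrupts one field (named either by a numeric byte
    # offset or by one of the field names understood by corrupt()) in every
    # share. Damage anywhere in the signed prefix shows up as a bad
    # signature; damage to the hash chains or blocks shows up as a hash tree
    # failure.
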
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

1190     def test_corrupt_some(self):
1191         # corrupt the data of the first five shares (so the servermap thinks
1192         # they're good but retrieve marks them as bad), so that the
1193         # MODE_READ set of 6 will be insufficient, forcing node.download to
1194         # retry with more servers.
1195         corrupt(None, self._storage, "share_data", range(5))
1196         d = self.make_servermap()
1197         def _do_retrieve(servermap):
1198             ver = servermap.best_recoverable_version()
1199             self.failUnless(ver)
1200             return self._fn.download_best_version()
1201         d.addCallback(_do_retrieve)
1202         d.addCallback(lambda new_contents:
1203                       self.failUnlessEqual(new_contents, self.CONTENTS))
1204         return d
1205
1206     def test_download_fails(self):
1207         corrupt(None, self._storage, "signature")
1208         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1209                             "no recoverable versions",
1210                             self._fn.download_best_version)
1211         return d
1212
1213
1214 class CheckerMixin:
1215     def check_good(self, r, where):
1216         self.failUnless(r.is_healthy(), where)
1217         return r
1218
1219     def check_bad(self, r, where):
1220         self.failIf(r.is_healthy(), where)
1221         return r
1222
1223     def check_expected_failure(self, r, expected_exception, substring, where):
1224         for (peerid, storage_index, shnum, f) in r.problems:
1225             if f.check(expected_exception):
1226                 self.failUnless(substring in str(f),
1227                                 "%s: substring '%s' not in '%s'" %
1228                                 (where, substring, str(f)))
1229                 return
1230         self.fail("%s: didn't see expected exception %s in problems %s" %
1231                   (where, expected_exception, r.problems))
1232
1233
1234 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1235     def setUp(self):
1236         return self.publish_one()
1237
1238
1239     def test_check_good(self):
1240         d = self._fn.check(Monitor())
1241         d.addCallback(self.check_good, "test_check_good")
1242         return d
1243
1244     def test_check_no_shares(self):
1245         for shares in self._storage._peers.values():
1246             shares.clear()
1247         d = self._fn.check(Monitor())
1248         d.addCallback(self.check_bad, "test_check_no_shares")
1249         return d
1250
1251     def test_check_not_enough_shares(self):
1252         for shares in self._storage._peers.values():
1253             for shnum in shares.keys():
1254                 if shnum > 0:
1255                     del shares[shnum]
1256         d = self._fn.check(Monitor())
1257         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1258         return d
1259
1260     def test_check_all_bad_sig(self):
1261         corrupt(None, self._storage, 1) # bad sig
1262         d = self._fn.check(Monitor())
1263         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1264         return d
1265
1266     def test_check_all_bad_blocks(self):
1267         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1268         # the Checker won't notice this: it doesn't look at the actual data
1269         d = self._fn.check(Monitor())
1270         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1271         return d
1272
1273     def test_verify_good(self):
1274         d = self._fn.check(Monitor(), verify=True)
1275         d.addCallback(self.check_good, "test_verify_good")
1276         return d
1277
1278     def test_verify_all_bad_sig(self):
1279         corrupt(None, self._storage, 1) # bad sig
1280         d = self._fn.check(Monitor(), verify=True)
1281         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1282         return d
1283
1284     def test_verify_one_bad_sig(self):
1285         corrupt(None, self._storage, 1, [9]) # bad sig
1286         d = self._fn.check(Monitor(), verify=True)
1287         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1288         return d
1289
1290     def test_verify_one_bad_block(self):
1291         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1292         # the Verifier *will* notice this, since it examines every byte
1293         d = self._fn.check(Monitor(), verify=True)
1294         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1295         d.addCallback(self.check_expected_failure,
1296                       CorruptShareError, "block hash tree failure",
1297                       "test_verify_one_bad_block")
1298         return d
1299
1300     def test_verify_one_bad_sharehash(self):
1301         corrupt(None, self._storage, "share_hash_chain", [9], 5)
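             # (the trailing 5 is presumably corrupt()'s extra byte-offset
             # into the field, pushing the mangling past the entry's share
             # number and into the hash itself; an assumption about the
             # helper's signature)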
1302         d = self._fn.check(Monitor(), verify=True)
1303         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1304         d.addCallback(self.check_expected_failure,
1305                       CorruptShareError, "corrupt hashes",
1306                       "test_verify_one_bad_sharehash")
1307         return d
1308
1309     def test_verify_one_bad_encprivkey(self):
1310         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1311         d = self._fn.check(Monitor(), verify=True)
1312         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1313         d.addCallback(self.check_expected_failure,
1314                       CorruptShareError, "invalid privkey",
1315                       "test_verify_one_bad_encprivkey")
1316         return d
1317
1318     def test_verify_one_bad_encprivkey_uncheckable(self):
1319         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1320         readonly_fn = self._fn.get_readonly()
1321         # a read-only node has no way to validate the privkey
1322         d = readonly_fn.check(Monitor(), verify=True)
1323         d.addCallback(self.check_good,
1324                       "test_verify_one_bad_encprivkey_uncheckable")
1325         return d
1326
1327 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1328
1329     def get_shares(self, s):
1330         all_shares = {} # maps (peerid, shnum) to share data
1331         for peerid in s._peers:
1332             shares = s._peers[peerid]
1333             for shnum in shares:
1334                 data = shares[shnum]
1335                 all_shares[ (peerid, shnum) ] = data
1336         return all_shares
1337
1338     def copy_shares(self, ignored=None):
1339         self.old_shares.append(self.get_shares(self._storage))
1340
1341     def test_repair_nop(self):
1342         self.old_shares = []
1343         d = self.publish_one()
1344         d.addCallback(self.copy_shares)
1345         d.addCallback(lambda res: self._fn.check(Monitor()))
1346         d.addCallback(lambda check_results: self._fn.repair(check_results))
1347         def _check_results(rres):
1348             self.failUnless(IRepairResults.providedBy(rres))
1349             # TODO: examine results
1350
1351             self.copy_shares()
1352
1353             initial_shares = self.old_shares[0]
1354             new_shares = self.old_shares[1]
1355             # TODO: this really shouldn't change anything. When we implement
1356             # a "minimal-bandwidth" repairer", change this test to assert:
1357             #self.failUnlessEqual(new_shares, initial_shares)
1358
1359             # all shares should be in the same place as before
1360             self.failUnlessEqual(set(initial_shares.keys()),
1361                                  set(new_shares.keys()))
1362             # but they should all be at a newer seqnum. The IV will be
1363             # different, so the roothash will be too.
1364             for key in initial_shares:
1365                 (version0,
1366                  seqnum0,
1367                  root_hash0,
1368                  IV0,
1369                  k0, N0, segsize0, datalen0,
1370                  o0) = unpack_header(initial_shares[key])
1371                 (version1,
1372                  seqnum1,
1373                  root_hash1,
1374                  IV1,
1375                  k1, N1, segsize1, datalen1,
1376                  o1) = unpack_header(new_shares[key])
1377                 self.failUnlessEqual(version0, version1)
1378                 self.failUnlessEqual(seqnum0+1, seqnum1)
1379                 self.failUnlessEqual(k0, k1)
1380                 self.failUnlessEqual(N0, N1)
1381                 self.failUnlessEqual(segsize0, segsize1)
1382                 self.failUnlessEqual(datalen0, datalen1)
1383         d.addCallback(_check_results)
1384         return d
1385
1386     def failIfSharesChanged(self, ignored=None):
1387         old_shares = self.old_shares[-2]
1388         current_shares = self.old_shares[-1]
1389         self.failUnlessEqual(old_shares, current_shares)
1390
1391     def test_merge(self):
1392         self.old_shares = []
1393         d = self.publish_multiple()
1394         # repair will refuse to merge multiple highest seqnums unless you
1395         # pass force=True
1396         d.addCallback(lambda res:
1397                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1398                                           1:4,3:4,5:4,7:4,9:4}))
1399         d.addCallback(self.copy_shares)
1400         d.addCallback(lambda res: self._fn.check(Monitor()))
1401         def _try_repair(check_results):
1402             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1403             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1404                                  self._fn.repair, check_results)
1405             d2.addCallback(self.copy_shares)
1406             d2.addCallback(self.failIfSharesChanged)
1407             d2.addCallback(lambda res: check_results)
1408             return d2
1409         d.addCallback(_try_repair)
1410         d.addCallback(lambda check_results:
1411                       self._fn.repair(check_results, force=True))
1412         # this should give us 10 shares of the highest roothash
1413         def _check_repair_results(rres):
1414             pass # TODO
1415         d.addCallback(_check_repair_results)
1416         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1417         def _check_smap(smap):
1418             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1419             self.failIf(smap.unrecoverable_versions())
1420             # now, which should have won?
1421             roothash_s4a = self.get_roothash_for(3)
1422             roothash_s4b = self.get_roothash_for(4)
1423             if roothash_s4b > roothash_s4a:
1424                 expected_contents = self.CONTENTS[4]
1425             else:
1426                 expected_contents = self.CONTENTS[3]
1427             new_versionid = smap.best_recoverable_version()
1428             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1429             d2 = self._fn.download_version(smap, new_versionid)
1430             d2.addCallback(self.failUnlessEqual, expected_contents)
1431             return d2
1432         d.addCallback(_check_smap)
1433         return d
1434
1435     def test_non_merge(self):
1436         self.old_shares = []
1437         d = self.publish_multiple()
1438         # repair should not refuse a repair that doesn't need to merge. In
1439         # this case, we combine v2 with v3. The repair should ignore v2 and
1440         # copy v3 into a new v5.
1441         d.addCallback(lambda res:
1442                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1443                                           1:3,3:3,5:3,7:3,9:3}))
1444         d.addCallback(lambda res: self._fn.check(Monitor()))
1445         d.addCallback(lambda check_results: self._fn.repair(check_results))
1446         # this should give us 10 shares of v3
1447         def _check_repair_results(rres):
1448             pass # TODO
1449         d.addCallback(_check_repair_results)
1450         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1451         def _check_smap(smap):
1452             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1453             self.failIf(smap.unrecoverable_versions())
1454             # v3 should have won outright; there is no roothash race here
1456             expected_contents = self.CONTENTS[3]
1457             new_versionid = smap.best_recoverable_version()
1458             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1459             d2 = self._fn.download_version(smap, new_versionid)
1460             d2.addCallback(self.failUnlessEqual, expected_contents)
1461             return d2
1462         d.addCallback(_check_smap)
1463         return d
1464
1465     def get_roothash_for(self, index):
1466         # return the roothash for the first share we see in the saved set
1467         shares = self._copied_shares[index]
1468         for peerid in shares:
1469             for shnum in shares[peerid]:
1470                 share = shares[peerid][shnum]
1471                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1472                           unpack_header(share)
1473                 return root_hash
1474
1475 class MultipleEncodings(unittest.TestCase):
1476     def setUp(self):
1477         self.CONTENTS = "New contents go here"
1478         num_peers = 20
1479         self._client = FakeClient(num_peers)
1480         self._storage = self._client._storage
1481         d = self._client.create_mutable_file(self.CONTENTS)
1482         def _created(node):
1483             self._fn = node
1484         d.addCallback(_created)
1485         return d
1486
1487     def _encode(self, k, n, data):
1488         # encode 'data' into a peerid->shares dict.
1489
1490         fn2 = FastMutableFileNode(self._client)
1491         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1492         # and _fingerprint
1493         fn = self._fn
1494         fn2.init_from_uri(fn.get_uri())
1495         # then we copy over other fields that are normally fetched from the
1496         # existing shares
1497         fn2._pubkey = fn._pubkey
1498         fn2._privkey = fn._privkey
1499         fn2._encprivkey = fn._encprivkey
1500         # and set the encoding parameters to something completely different
1501         fn2._required_shares = k
1502         fn2._total_shares = n
1503
1504         s = self._client._storage
1505         s._peers = {} # clear existing storage
1506         p2 = Publish(fn2, None)
1507         d = p2.publish(data)
1508         def _published(res):
1509             shares = s._peers
1510             s._peers = {}
1511             return shares
1512         d.addCallback(_published)
1513         return d
1514
1515     def make_servermap(self, mode=MODE_READ, oldmap=None):
1516         if oldmap is None:
1517             oldmap = ServerMap()
1518         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1519         d = smu.update()
1520         return d
1521
1522     def test_multiple_encodings(self):
1523         # we encode the same file in two different ways (3-of-10 and 4-of-9),
1524         # then mix up the shares, to make sure that download survives seeing
1525         # a variety of encodings. This is actually kind of tricky to set up.
1526
1527         contents1 = "Contents for encoding 1 (3-of-10) go here"
1528         contents2 = "Contents for encoding 2 (4-of-9) go here"
1529         contents3 = "Contents for encoding 3 (4-of-7) go here"
1530
1531         # we make a retrieval object that doesn't know what encoding
1532         # parameters to use
1533         fn3 = FastMutableFileNode(self._client)
1534         fn3.init_from_uri(self._fn.get_uri())
1535
1536         # now we publish the first encoding of the file and grab its shares
1537         d = self._encode(3, 10, contents1)
1538         def _encoded_1(shares):
1539             self._shares1 = shares
1540         d.addCallback(_encoded_1)
1541         d.addCallback(lambda res: self._encode(4, 9, contents2))
1542         def _encoded_2(shares):
1543             self._shares2 = shares
1544         d.addCallback(_encoded_2)
1545         d.addCallback(lambda res: self._encode(4, 7, contents3))
1546         def _encoded_3(shares):
1547             self._shares3 = shares
1548         d.addCallback(_encoded_3)
1549
1550         def _merge(res):
1551             log.msg("merging sharelists")
1552             # we merge the shares from the three sets, leaving each shnum in
1553             # its original location, but using a share from set1, set2, or set3
1554             # according to the following sequence:
1555             #
1556             #  4-of-9  a  s2
1557             #  4-of-9  b  s2
1558             #  4-of-7  c   s3
1559             #  4-of-9  d  s2
1560             #  3-of-10 e s1
1561             #  3-of-10 f s1
1562             #  3-of-10 g s1
1563             #  4-of-9  h  s2
1564             #
1565             # so that neither form can be recovered until fetch [f], at which
1566             # point version-s1 (the 3-of-10 form) should be recoverable. If
1567             # the implementation latches on to the first version it sees,
1568             # then s2 will be recoverable at fetch [g].
1569
1570             # Later, when we implement code that handles multiple versions,
1571             # we can use this framework to assert that all recoverable
1572             # versions are retrieved, and test that 'epsilon' does its job
1573
1574             places = [2, 2, 3, 2, 1, 1, 1, 2]
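                 # places[shnum] selects which sharelist supplies that shnum:
                 # 1 -> self._shares1 (3-of-10), 2 -> self._shares2 (4-of-9),
                 # 3 -> self._shares3 (4-of-7); shnums beyond len(places) get
                 # "x" and are not placed at all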
1575
1576             sharemap = {}
1577
1578             for i,peerid in enumerate(self._client._peerids):
1579                 peerid_s = shortnodeid_b2a(peerid)
1580                 # create the peer's bucket once, not once per share
1581                 self._client._storage._peers[peerid] = peers = {}
1582                 for shnum in self._shares1.get(peerid, {}):
1583                     if shnum < len(places):
1584                         which = places[shnum]
1585                     else:
1586                         which = "x"
1586                     in_1 = shnum in self._shares1[peerid]
1587                     in_2 = shnum in self._shares2.get(peerid, {})
1588                     in_3 = shnum in self._shares3.get(peerid, {})
1589                     #print peerid_s, shnum, which, in_1, in_2, in_3
1590                     if which == 1:
1591                         if in_1:
1592                             peers[shnum] = self._shares1[peerid][shnum]
1593                             sharemap[shnum] = peerid
1594                     elif which == 2:
1595                         if in_2:
1596                             peers[shnum] = self._shares2[peerid][shnum]
1597                             sharemap[shnum] = peerid
1598                     elif which == 3:
1599                         if in_3:
1600                             peers[shnum] = self._shares3[peerid][shnum]
1601                             sharemap[shnum] = peerid
1602
1603             # we don't bother placing any other shares
1604             # now sort the sequence so that share 0 is returned first
1605             new_sequence = [sharemap[shnum]
1606                             for shnum in sorted(sharemap.keys())]
1607             self._client._storage._sequence = new_sequence
1608             log.msg("merge done")
1609         d.addCallback(_merge)
1610         d.addCallback(lambda res: fn3.download_best_version())
1611         def _retrieved(new_contents):
1612             # the current specified behavior is "first version recoverable"
1613             self.failUnlessEqual(new_contents, contents1)
1614         d.addCallback(_retrieved)
1615         return d
1616
1617
1618 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1619
1620     def setUp(self):
1621         return self.publish_multiple()
1622
1623     def test_multiple_versions(self):
1624         # if we see a mix of versions in the grid, download_best_version
1625         # should get the latest one
1626         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1627         d = self._fn.download_best_version()
1628         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1629         # and the checker should report problems
1630         d.addCallback(lambda res: self._fn.check(Monitor()))
1631         d.addCallback(self.check_bad, "test_multiple_versions")
1632
1633         # but if everything is at version 2, that's what we should download
1634         d.addCallback(lambda res:
1635                       self._set_versions(dict([(i,2) for i in range(10)])))
1636         d.addCallback(lambda res: self._fn.download_best_version())
1637         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1638         # if exactly one share is at version 3, we should still get v2
1639         d.addCallback(lambda res:
1640                       self._set_versions({0:3}))
1641         d.addCallback(lambda res: self._fn.download_best_version())
1642         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1643         # but the servermap should see the unrecoverable version. This
1644         # depends upon the single newer share being queried early.
1645         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1646         def _check_smap(smap):
1647             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1648             newer = smap.unrecoverable_newer_versions()
1649             self.failUnlessEqual(len(newer), 1)
1650             verinfo, health = newer.items()[0]
1651             self.failUnlessEqual(verinfo[0], 4)
1652             self.failUnlessEqual(health, (1,3))
1653             self.failIf(smap.needs_merge())
1654         d.addCallback(_check_smap)
1655         # if we have a mix of two parallel versions (s4a and s4b), we could
1656         # recover either
1657         d.addCallback(lambda res:
1658                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1659                                           1:4,3:4,5:4,7:4,9:4}))
1660         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1661         def _check_smap_mixed(smap):
1662             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1663             newer = smap.unrecoverable_newer_versions()
1664             self.failUnlessEqual(len(newer), 0)
1665             self.failUnless(smap.needs_merge())
1666         d.addCallback(_check_smap_mixed)
1667         d.addCallback(lambda res: self._fn.download_best_version())
1668         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1669                                                   res == self.CONTENTS[4]))
1670         return d
1671
1672     def test_replace(self):
1673         # if we see a mix of versions in the grid, we should be able to
1674         # replace them all with a newer version
1675
1676         # if exactly one share is at version 3, we should download (and
1677         # replace) v2, and the result should be v4. Note that the index we
1678         # give to _set_versions is different from the sequence number.
1679         target = dict([(i,2) for i in range(10)]) # seqnum3
1680         target[0] = 3 # seqnum4
1681         self._set_versions(target)
1682
1683         def _modify(oldversion, servermap, first_time):
1684             return oldversion + " modified"
1685         d = self._fn.modify(_modify)
1686         d.addCallback(lambda res: self._fn.download_best_version())
1687         expected = self.CONTENTS[2] + " modified"
1688         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1689         # and the servermap should indicate that the outlier was replaced too
1690         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1691         def _check_smap(smap):
1692             self.failUnlessEqual(smap.highest_seqnum(), 5)
1693             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1694             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1695         d.addCallback(_check_smap)
1696         return d
1697
1698
1699 class Utils(unittest.TestCase):
1700     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1701         # we compare this against sets of integers
1702         x = set(range(x_start, x_start+x_length))
1703         y = set(range(y_start, y_start+y_length))
1704         should_be_inside = x.issubset(y)
1705         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1706                                                          y_start, y_length),
1707                              str((x_start, x_length, y_start, y_length)))
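         # For reference, a minimal sketch of the containment test being
         # verified here (an assumption about ResponseCache._inside's
         # semantics, valid for the non-empty x ranges used in this test):
         #
         #   def _inside(self, x_start, x_length, y_start, y_length):
         #       # is [x_start, x_start+x_length) entirely within
         #       # [y_start, y_start+y_length) ?
         #       return (x_start >= y_start and
         #               x_start + x_length <= y_start + y_length)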
1708
1709     def test_cache_inside(self):
1710         c = ResponseCache()
1711         x_start = 10
1712         x_length = 5
1713         for y_start in range(8, 17):
1714             for y_length in range(8):
1715                 self._do_inside(c, x_start, x_length, y_start, y_length)
1716
1717     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1718         # we compare this against sets of integers
1719         x = set(range(x_start, x_start+x_length))
1720         y = set(range(y_start, y_start+y_length))
1721         overlap = bool(x.intersection(y))
1722         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1723                                                       y_start, y_length),
1724                              str((x_start, x_length, y_start, y_length)))
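         # Similarly, a sketch of the presumed overlap test for half-open
         # ranges (correct for non-empty ranges; an empty range never
         # overlaps anything):
         #
         #   def _does_overlap(self, x_start, x_length, y_start, y_length):
         #       return (x_start < y_start + y_length and
         #               y_start < x_start + x_length)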
1725
1726     def test_cache_overlap(self):
1727         c = ResponseCache()
1728         x_start = 10
1729         x_length = 5
1730         for y_start in range(8, 17):
1731             for y_length in range(8):
1732                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1733
1734     def test_cache(self):
1735         c = ResponseCache()
1736         # xdata = base62.b2a(os.urandom(100))[:100]
1737         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1738         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1739         nope = (None, None)
1740         c.add("v1", 1, 0, xdata, "time0")
1741         c.add("v1", 1, 2000, ydata, "time1")
1742         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1743         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1744         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1745         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1746         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1747         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1748         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1749         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1750         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1751         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1752         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1753         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1754         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1755         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1756         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1757         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1758         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1759         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1760
1761         # optional: join fragments
1762         c = ResponseCache()
1763         c.add("v1", 1, 0, xdata[:10], "time0")
1764         c.add("v1", 1, 10, xdata[10:20], "time1")
1765         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1766
1767 class Exceptions(unittest.TestCase):
1768     def test_repr(self):
1769         nmde = NeedMoreDataError(100, 50, 100)
1770         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1771         ucwe = UncoordinatedWriteError()
1772         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1773
1774 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1775 # instances which always succeed. So we need a less-fake one.
1776
1777 class IntentionalError(Exception):
1778     pass
1779
1780 class LocalWrapper:
1781     def __init__(self, original):
1782         self.original = original
1783         self.broken = False
1784         self.post_call_notifier = None
1785     def callRemote(self, methname, *args, **kwargs):
1786         def _call():
1787             if self.broken:
1788                 raise IntentionalError("I was asked to break")
1789             meth = getattr(self.original, "remote_" + methname)
1790             return meth(*args, **kwargs)
1791         d = fireEventually()
1792         d.addCallback(lambda res: _call())
1793         if self.post_call_notifier:
1794             d.addCallback(self.post_call_notifier, methname)
1795         return d
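     # A hypothetical usage sketch (not part of the tests): LocalWrapper makes
     # a local object look enough like a foolscap RemoteReference that a test
     # can break it mid-run:
     #
     #   ss = StorageServer(peerdir, peerid)
     #   lw = LocalWrapper(ss)
     #   d = lw.callRemote("get_buckets", si) # runs ss.remote_get_buckets(si)
     #   lw.broken = True # subsequent calls errback with IntentionalError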
1796
1797 class LessFakeClient(FakeClient):
1798
1799     def __init__(self, basedir, num_peers=10):
1800         self._num_peers = num_peers
1801         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1802                          for i in range(self._num_peers)]
1803         self._connections = {}
1804         for peerid in self._peerids:
1805             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1806             make_dirs(peerdir)
1807             ss = StorageServer(peerdir, peerid)
1808             lw = LocalWrapper(ss)
1809             self._connections[peerid] = lw
1810         self.nodeid = "fakenodeid"
1811
1812
1813 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1814     def test_publish_surprise(self):
1815         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1816         self.client = LessFakeClient(basedir)
1817         d = self.client.create_mutable_file("contents 1")
1818         def _created(n):
1819             d = defer.succeed(None)
1820             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1821             def _got_smap1(smap):
1822                 # stash the old state of the file
1823                 self.old_map = smap
1824             d.addCallback(_got_smap1)
1825             # then modify the file, leaving the old map untouched
1826             d.addCallback(lambda res: log.msg("starting winning write"))
1827             d.addCallback(lambda res: n.overwrite("contents 2"))
1828             # now attempt to modify the file with the old servermap. This
1829             # will look just like an uncoordinated write, in which every
1830             # single share got updated between our mapupdate and our publish
1831             d.addCallback(lambda res: log.msg("starting doomed write"))
1832             d.addCallback(lambda res:
1833                           self.shouldFail(UncoordinatedWriteError,
1834                                           "test_publish_surprise", None,
1835                                           n.upload,
1836                                           "contents 2a", self.old_map))
1837             return d
1838         d.addCallback(_created)
1839         return d
1840
1841     def test_retrieve_surprise(self):
1842         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1843         self.client = LessFakeClient(basedir)
1844         d = self.client.create_mutable_file("contents 1")
1845         def _created(n):
1846             d = defer.succeed(None)
1847             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1848             def _got_smap1(smap):
1849                 # stash the old state of the file
1850                 self.old_map = smap
1851             d.addCallback(_got_smap1)
1852             # then modify the file, leaving the old map untouched
1853             d.addCallback(lambda res: log.msg("starting winning write"))
1854             d.addCallback(lambda res: n.overwrite("contents 2"))
1855             # now attempt to retrieve the old version with the old servermap.
1856             # This will look like someone has changed the file since we
1857             # updated the servermap.
1858             d.addCallback(lambda res: n._cache._clear())
1859             d.addCallback(lambda res: log.msg("starting doomed read"))
1860             d.addCallback(lambda res:
1861                           self.shouldFail(NotEnoughSharesError,
1862                                           "test_retrieve_surprise",
1863                                           "ran out of peers: have 0 shares (k=3)",
1864                                           n.download_version,
1865                                           self.old_map,
1866                                           self.old_map.best_recoverable_version(),
1867                                           ))
1868             return d
1869         d.addCallback(_created)
1870         return d
1871
1872     def test_unexpected_shares(self):
1873         # upload the file, take a servermap, shut down one of the servers,
1874         # upload it again (causing shares to appear on a new server), then
1875         # upload using the old servermap. The last upload should fail with an
1876         # UncoordinatedWriteError, because of the shares that didn't appear
1877         # in the servermap.
1878         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1879         self.client = LessFakeClient(basedir)
1880         d = self.client.create_mutable_file("contents 1")
1881         def _created(n):
1882             d = defer.succeed(None)
1883             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1884             def _got_smap1(smap):
1885                 # stash the old state of the file
1886                 self.old_map = smap
1887                 # now shut down one of the servers
1888                 peer0 = list(smap.make_sharemap()[0])[0]
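                     # (make_sharemap() maps shnum to the set of peerids
                     # holding that share; this grabs one holder of share 0)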
1889                 self.client._connections.pop(peer0)
1890                 # then modify the file, leaving the old map untouched
1891                 log.msg("starting winning write")
1892                 return n.overwrite("contents 2")
1893             d.addCallback(_got_smap1)
1894             # now attempt to modify the file with the old servermap. This
1895             # will look just like an uncoordinated write, in which every
1896             # single share got updated between our mapupdate and our publish
1897             d.addCallback(lambda res: log.msg("starting doomed write"))
1898             d.addCallback(lambda res:
1899                           self.shouldFail(UncoordinatedWriteError,
1900                                           "test_unexpected_shares", None,
1901                                           n.upload,
1902                                           "contents 2a", self.old_map))
1903             return d
1904         d.addCallback(_created)
1905         return d
1906
1907     def test_bad_server(self):
1908         # Break one server, then create the file: the initial publish should
1909         # complete with an alternate server. Breaking a second server should
1910         # not prevent an update from succeeding either.
1911         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1912         self.client = LessFakeClient(basedir, 20)
1913         # to make sure that one of the initial peers is broken, we have to
1914         # get creative. We create the keys, so we can figure out the storage
1915         # index, but we hold off on doing the initial publish until we've
1916         # broken the server on which the first share wants to be stored.
1917         n = FastMutableFileNode(self.client)
1918         d = defer.succeed(None)
1919         d.addCallback(n._generate_pubprivkeys)
1920         d.addCallback(n._generated)
1921         def _break_peer0(res):
1922             si = n.get_storage_index()
1923             peerlist = self.client.get_permuted_peers("storage", si)
1924             peerid0, connection0 = peerlist[0]
1925             peerid1, connection1 = peerlist[1]
1926             connection0.broken = True
1927             self.connection1 = connection1
1928         d.addCallback(_break_peer0)
1929         # now let the initial publish finally happen
1930         d.addCallback(lambda res: n._upload("contents 1", None))
1931         # that ought to work
1932         d.addCallback(lambda res: n.download_best_version())
1933         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1934         # now break the second peer
1935         def _break_peer1(res):
1936             self.connection1.broken = True
1937         d.addCallback(_break_peer1)
1938         d.addCallback(lambda res: n.overwrite("contents 2"))
1939         # that ought to work too
1940         d.addCallback(lambda res: n.download_best_version())
1941         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1942         return d
1943
1944     def test_bad_server_overlap(self):
1945         # like test_bad_server, but with no extra unused servers to fall back
1946         # upon. This means that we must re-use a server which we've already
1947         # used. If we don't remember the fact that we sent them one share
1948         # already, we'll mistakenly think we're experiencing an
1949         # UncoordinatedWriteError.
1950
1951         # Break one server, then create the file: the initial publish should
1952         # complete with an alternate server. Breaking a second server should
1953         # not prevent an update from succeeding either.
1954         basedir = os.path.join("mutable/CollidingWrites/test_bad_server_overlap")
1955         self.client = LessFakeClient(basedir, 10)
1956
1957         peerids = sorted(self.client._connections.keys())
1958         self.client._connections[peerids[0]].broken = True
1959
1960         d = self.client.create_mutable_file("contents 1")
1961         def _created(n):
1962             d = n.download_best_version()
1963             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1964             # now break one of the remaining servers
1965             def _break_second_server(res):
1966                 self.client._connections[peerids[1]].broken = True
1967             d.addCallback(_break_second_server)
1968             d.addCallback(lambda res: n.overwrite("contents 2"))
1969             # that ought to work too
1970             d.addCallback(lambda res: n.download_best_version())
1971             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1972             return d
1973         d.addCallback(_created)
1974         return d
1975
1976     def test_publish_all_servers_bad(self):
1977         # Break all servers: the publish should fail
1978         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1979         self.client = LessFakeClient(basedir, 20)
1980         for connection in self.client._connections.values():
1981             connection.broken = True
1982         d = self.shouldFail(NotEnoughServersError,
1983                             "test_publish_all_servers_bad",
1984                             "Ran out of non-bad servers",
1985                             self.client.create_mutable_file, "contents")
1986         return d
1987
1988     def test_publish_no_servers(self):
1989         # no servers at all: the publish should fail
1990         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1991         self.client = LessFakeClient(basedir, 0)
1992         d = self.shouldFail(NotEnoughServersError,
1993                             "test_publish_no_servers",
1994                             "Ran out of non-bad servers",
1995                             self.client.create_mutable_file, "contents")
1996         return d
1997     test_publish_no_servers.timeout = 30
1998
1999
2000     def test_privkey_query_error(self):
2001         # when a servermap is updated with MODE_WRITE, it tries to get the
2002         # privkey. Something might go wrong during this query attempt.
2003         self.client = FakeClient(20)
2004         # we need some contents that are large enough to push the privkey out
2005         # of the early part of the file
2006         LARGE = "These are Larger contents" * 200 # about 5KB
2007         d = self.client.create_mutable_file(LARGE)
2008         def _created(n):
2009             self.uri = n.get_uri()
2010             self.n2 = self.client.create_node_from_uri(self.uri)
2011             # we start by doing a map update to figure out which is the first
2012             # server.
2013             return n.get_servermap(MODE_WRITE)
2014         d.addCallback(_created)
2015         d.addCallback(lambda res: fireEventually(res))
2016         def _got_smap1(smap):
2017             peer0 = list(smap.make_sharemap()[0])[0]
2018             # we tell the server to respond to this peer first, so that it
2019             # will be asked for the privkey first
2020             self.client._storage._sequence = [peer0]
2021             # now we make the peer fail their second query
2022             self.client._storage._special_answers[peer0] = ["normal", "fail"]
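                 # (["normal", "fail"] means: answer peer0's first read
                 # normally, then errback its second read, the privkey fetch,
                 # with IntentionalError)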
2023         d.addCallback(_got_smap1)
2024         # now we update a servermap from a new node (which doesn't have the
2025         # privkey yet, forcing it to use a separate privkey query). Each
2026         # query response will trigger a privkey query, and since we're using
2027         # _sequence to make the peer0 response come back first, we'll send it
2028         # a privkey query first, and _sequence will again ensure that the
2029         # peer0 query will also come back before the others, and then
2030         # _special_answers will make sure that the query raises an exception.
2031         # The whole point of these hijinks is to exercise the code in
2032         # _privkey_query_failed. Note that the map-update will succeed, since
2033         # we'll just get a copy from one of the other shares.
2034         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2035         # Using FakeStorage._sequence means there will be read requests still
2036         # floating around; wait for them to retire
2037         def _cancel_timer(res):
2038             if self.client._storage._pending_timer:
2039                 self.client._storage._pending_timer.cancel()
2040             return res
2041         d.addBoth(_cancel_timer)
2042         return d
2043
2044     def test_privkey_query_missing(self):
2045         # like test_privkey_query_error, but the shares are deleted by the
2046         # second query, instead of raising an exception.
2047         self.client = FakeClient(20)
2048         LARGE = "These are Larger contents" * 200 # about 5KB
2049         d = self.client.create_mutable_file(LARGE)
2050         def _created(n):
2051             self.uri = n.get_uri()
2052             self.n2 = self.client.create_node_from_uri(self.uri)
2053             return n.get_servermap(MODE_WRITE)
2054         d.addCallback(_created)
2055         d.addCallback(lambda res: fireEventually(res))
2056         def _got_smap1(smap):
2057             peer0 = list(smap.make_sharemap()[0])[0]
2058             self.client._storage._sequence = [peer0]
2059             self.client._storage._special_answers[peer0] = ["normal", "none"]
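                 # (["normal", "none"] means: answer the first read normally,
                 # then report no shares on the second, as if they had just
                 # been deleted)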
2060         d.addCallback(_got_smap1)
2061         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2062         def _cancel_timer(res):
2063             if self.client._storage._pending_timer:
2064                 self.client._storage._pending_timer.cancel()
2065             return res
2066         d.addBoth(_cancel_timer)
2067         return d