]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_mutable.py
more refactoring: move get_all_serverids() and get_nickname_for_serverid() from Clien...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_mutable.py
1
2 import os, struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from twisted.python import failure
7 from allmydata import uri
8 from allmydata.storage.server import StorageServer
9 from allmydata.immutable import download
10 from allmydata.util import base32, idlib
11 from allmydata.util.idlib import shortnodeid_b2a
12 from allmydata.util.hashutil import tagged_hash
13 from allmydata.util.fileutil import make_dirs
14 from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
15      FileTooLargeError, NotEnoughSharesError, IRepairResults
16 from allmydata.monitor import Monitor
17 from allmydata.test.common import ShouldFailMixin
18 from foolscap.api import eventually, fireEventually
19 from foolscap.logging import log
20 from allmydata.storage_client import StorageFarmBroker
21
22 from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
23 from allmydata.mutable.common import ResponseCache, \
24      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
25      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
26      NotEnoughServersError, CorruptShareError
27 from allmydata.mutable.retrieve import Retrieve
28 from allmydata.mutable.publish import Publish
29 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
30 from allmydata.mutable.layout import unpack_header, unpack_share
31 from allmydata.mutable.repairer import MustForceRepairError
32
33 import common_util as testutil
34
35 # this "FastMutableFileNode" exists solely to speed up tests by using smaller
36 # public/private keys. Once we switch to fast DSA-based keys, we can get rid
37 # of this.
38
class FastMutableFileNode(MutableFileNode):
    """MutableFileNode subclass used only to make the tests run faster.

    It swaps in a small (522-bit) signature key in place of the default key
    size, so key generation during tests is cheap.
    """
    SIGNATURE_KEY_SIZE = 522
41
42 # this "FakeStorage" exists to put the share data in RAM and avoid using real
43 # network connections, both to speed up the tests and to reduce the amount of
44 # non-mutable.py code being exercised.
45
class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve .
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        # peerid -> {shnum: share data string}
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        # peerid -> list of (Deferred, shares) waiting for _fire_readers().
        # It is a list because the same peer may be queried more than once
        # before the timer fires; a single slot would silently drop (and
        # never fire) the earlier Deferred.
        self._pending = {}
        self._pending_timer = None
        # peerid -> list of answer modes ("fail", "none", "normal"); one mode
        # is consumed per read() of that peer.
        self._special_answers = {}

    def read(self, peerid, storage_index):
        """Return a Deferred firing with peerid's {shnum: data} dict.

        storage_index is ignored (see class comment). If a special answer
        mode is queued for this peer, it is consumed: "fail" makes the
        Deferred fail with IntentionalError, "none" answers with no shares,
        and "normal" answers with the real shares. When _sequence is set,
        the answer is held back until _fire_readers() runs.
        """
        shares = self._peers.get(peerid, {})
        modes = self._special_answers.get(peerid)
        if modes:
            mode = modes.pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            # "normal" (or any unrecognized mode) keeps the real shares
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            # first held-back query: schedule the release one second from now
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending.setdefault(peerid, []).append((d, shares))
        return d

    def _fire_readers(self):
        """Release every held-back read, in _sequence order first."""
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        # answer the peers named in _sequence, in that order...
        for peerid in self._sequence:
            for (d, shares) in pending.pop(peerid, []):
                eventually(d.callback, shares)
        # ...then everyone else who was queried but not named
        for waiters in pending.values():
            for (d, shares) in waiters:
                eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        """Overwrite share shnum on peerid with data, starting at offset.

        A share that does not exist yet starts out empty. storage_index is
        ignored.
        """
        shares = self._peers.setdefault(peerid, {})
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
105         f.seek(offset)
106         f.write(data)
107         shares[shnum] = f.getvalue()
108
109
class FakeStorageServer:
    """In-memory stand-in for one remote storage server.

    callRemote() dispatches to the local method of the same name on a later
    turn (via fireEventually), and all share reads and writes go through the
    shared FakeStorage object passed in as `storage`.
    """

    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0

    def callRemote(self, methname, *args, **kwargs):
        # the method is looked up only when the eventual-send fires
        def _invoke(ignored):
            target = getattr(self, methname)
            return target(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_invoke)
        return d

    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: results and errors are both discarded, and the
        # caller gets nothing back
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        # Answer with {shnum: [data slice for each (offset, length)]},
        # restricted to `shnums` unless that is empty.
        def _extract(all_shares):
            result = {}
            for shnum in all_shares:
                if shnums and shnum not in shnums:
                    continue
                pieces = result[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    pieces.append(all_shares[shnum][offset:offset+length])
            return result
        d = self.storage.read(self.peerid, storage_index)
        d.addCallback(_extract)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        read_data = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            read_data[shnum] = [vector[3] for vector in testv]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        return fireEventually((True, read_data))
160                                    offset, data)
161         answer = (True, readv)
162         return fireEventually(answer)
163
164
165 # our "FakeClient" has just enough functionality of the real Client to let
166 # the tests run.
167
class FakeClient:
    """Just enough of the real Client to let the mutable-file tests run.

    num_peers FakeStorageServers, all backed by one shared FakeStorage, are
    registered with a real StorageFarmBroker.
    """
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self.nodeid = "fakenodeid"
        self.storage_broker = StorageFarmBroker()
        for index in range(self._num_peers):
            # deterministic 20-byte peerid for each fake server
            server_id = tagged_hash("peerid", "%d" % index)[:20]
            server = FakeStorageServer(server_id, self._storage)
            self.storage_broker.add_server(server_id, server)

    def get_storage_broker(self):
        return self.storage_broker
    def debug_break_connection(self, peerid):
        self.storage_broker.servers[peerid].broken = True
    def debug_remove_connection(self, peerid):
        self.storage_broker.servers.pop(peerid)
    def debug_get_connection(self, peerid):
        return self.storage_broker.servers[peerid]

    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        # returns a Deferred that fires with the new node once it is created
        node = self.mutable_file_node_class(self)
        d = node.create(contents)
        d.addCallback(lambda ignored: node)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        return self.mutable_file_node_class(self).init_from_uri(u)

    def upload(self, uploadable):
        # No storage is touched: the whole uploadable is read and handed
        # back inline as a LiteralFileURI.
        assert IUploadable.providedBy(uploadable)
        def _read_everything(size):
            return uploadable.read(size)
        def _as_literal_uri(chunks):
            return uri.LiteralFileURI("".join(chunks))
        d = uploadable.get_size()
        d.addCallback(_read_everything)
        d.addCallback(_as_literal_uri)
        return d
225             return uri.LiteralFileURI(data)
226         d.addCallback(_got_data)
227         return d
228
229
def flip_bit(original, byte_offset):
    """Return a copy of *original* with the lowest bit of one byte toggled.

    byte_offset indexes the byte (character) to alter; an out-of-range
    offset raises IndexError.
    """
    head = original[:byte_offset]
    victim = original[byte_offset]
    tail = original[byte_offset+1:]
    flipped = chr(ord(victim) ^ 0x01)
    return head + flipped + tail
234
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    """Flip one bit in selected shares held by the FakeStorage *s*.

    Usable as a Deferred callback: *res* is returned unchanged.

    offset selects where to flip, and may be:
      - "pubkey": the hard-coded byte offset 107;
      - a key of the share's offset table (the last field returned by
        unpack_header), which resolves to that table entry;
      - any other value, which is used directly as a byte offset;
      - a (base, delta) tuple, where base is one of the above and delta is
        added to the resolved offset.
    offset_offset is added on top of everything else.

    If shnums_to_corrupt is None, every share on every peer is corrupted;
    otherwise only shares whose number is in shnums_to_corrupt.
    """
    # The offset spec is the same for every share, so split it just once.
    if isinstance(offset, tuple):
        offset1, offset2 = offset
    else:
        offset1, offset2 = offset, 0
    for shares in s._peers.values():
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            # only the offset table is needed; the other header fields are
            # unpacked to keep the header-shape check
            (_version, _seqnum, _root_hash, _IV,
             _k, _N, _segsize, _datalen,
             o) = unpack_header(data)
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            # replacing the value of an existing key does not disturb the
            # iteration over `shares`
            shares[shnum] = flip_bit(data, real_offset)
    return res
266
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # Tests of MutableFileNode create/upload/download/modify behavior, run
    # against the in-memory FakeClient and its FakeStorage.

    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.client.get_storage_broker()
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def _check_upload_and_download(self, n):
        """Run the shared overwrite/upload/download checks against node n.

        Used by both test_upload_and_download and
        test_upload_and_download_full_size_keys, which differ only in the
        node class (and therefore key size) and in an extra large-file check.
        Returns a Deferred.
        """
        d = defer.succeed(None)
        d.addCallback(lambda res: n.get_servermap(MODE_READ))
        d.addCallback(lambda smap: smap.dump(StringIO()))
        d.addCallback(lambda sio:
                      self.failUnless("3-of-10" in sio.getvalue()))
        d.addCallback(lambda res: n.overwrite("contents 1"))
        d.addCallback(lambda res: self.failUnlessIdentical(res, None))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        d.addCallback(lambda res: n.get_size_of_best_version())
        d.addCallback(lambda size:
                      self.failUnlessEqual(size, len("contents 1")))
        d.addCallback(lambda res: n.overwrite("contents 2"))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        d.addCallback(lambda res: n.download(download.Data()))
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
        d.addCallback(lambda smap: n.upload("contents 3", smap))
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
        d.addCallback(lambda smap:
                      n.download_version(smap,
                                         smap.best_recoverable_version()))
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = self._check_upload_and_download(n)
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        # both the create and the overwrite must be rejected with this text
        expected = ("SDMF is limited to one segment, and %d > %d" %
                    (len(BIG), Publish.MAX_SEGMENT_SIZE))
        d = self.shouldFail(FileTooLargeError, "too_large", expected,
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2", expected,
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        # assert that the best recoverable version of n has expected_seqnum;
        # `which` is the failure message
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        # same checks as test_upload_and_download, but with the real
        # (full-size key) MutableFileNode class
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        d.addCallback(self._check_upload_and_download)
        return d
558             return d
559         d.addCallback(_created)
560         return d
561
562
563 class MakeShares(unittest.TestCase):
564     def test_encrypt(self):
565         c = FakeClient()
566         fn = FastMutableFileNode(c)
567         CONTENTS = "some initial contents"
568         d = fn.create(CONTENTS)
569         def _created(res):
570             p = Publish(fn, None)
571             p.salt = "SALT" * 4
572             p.readkey = "\x00" * 16
573             p.newdata = CONTENTS
574             p.required_shares = 3
575             p.total_shares = 10
576             p.setup_encoding_parameters()
577             return p._encrypt_and_encode()
578         d.addCallback(_created)
579         def _done(shares_and_shareids):
580             (shares, share_ids) = shares_and_shareids
581             self.failUnlessEqual(len(shares), 10)
582             for sh in shares:
583                 self.failUnless(isinstance(sh, str))
584                 self.failUnlessEqual(len(sh), 7)
585             self.failUnlessEqual(len(share_ids), 10)
586         d.addCallback(_done)
587         return d
588
589     def test_generate(self):
590         c = FakeClient()
591         fn = FastMutableFileNode(c)
592         CONTENTS = "some initial contents"
593         d = fn.create(CONTENTS)
594         def _created(res):
595             p = Publish(fn, None)
596             self._p = p
597             p.newdata = CONTENTS
598             p.required_shares = 3
599             p.total_shares = 10
600             p.setup_encoding_parameters()
601             p._new_seqnum = 3
602             p.salt = "SALT" * 4
603             # make some fake shares
604             shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
605             p._privkey = fn.get_privkey()
606             p._encprivkey = fn.get_encprivkey()
607             p._pubkey = fn.get_pubkey()
608             return p._generate_shares(shares_and_ids)
609         d.addCallback(_created)
610         def _generated(res):
611             p = self._p
612             final_shares = p.shares
613             root_hash = p.root_hash
614             self.failUnlessEqual(len(root_hash), 32)
615             self.failUnless(isinstance(final_shares, dict))
616             self.failUnlessEqual(len(final_shares), 10)
617             self.failUnlessEqual(sorted(final_shares.keys()), range(10))
618             for i,sh in final_shares.items():
619                 self.failUnless(isinstance(sh, str))
620                 # feed the share through the unpacker as a sanity-check
621                 pieces = unpack_share(sh)
622                 (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
623                  pubkey, signature, share_hash_chain, block_hash_tree,
624                  share_data, enc_privkey) = pieces
625                 self.failUnlessEqual(u_seqnum, 3)
626                 self.failUnlessEqual(u_root_hash, root_hash)
627                 self.failUnlessEqual(k, 3)
628                 self.failUnlessEqual(N, 10)
629                 self.failUnlessEqual(segsize, 21)
630                 self.failUnlessEqual(datalen, len(CONTENTS))
631                 self.failUnlessEqual(pubkey, p._pubkey.serialize())
632                 sig_material = struct.pack(">BQ32s16s BBQQ",
633                                            0, p._new_seqnum, root_hash, IV,
634                                            k, N, segsize, datalen)
635                 self.failUnless(p._pubkey.verify(sig_material, signature))
636                 #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
637                 self.failUnless(isinstance(share_hash_chain, dict))
638                 self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
639                 for shnum,share_hash in share_hash_chain.items():
640                     self.failUnless(isinstance(shnum, int))
641                     self.failUnless(isinstance(share_hash, str))
642                     self.failUnlessEqual(len(share_hash), 32)
643                 self.failUnless(isinstance(block_hash_tree, list))
644                 self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
645                 self.failUnlessEqual(IV, "SALT"*4)
646                 self.failUnlessEqual(len(share_data), len("%07d" % 1))
647                 self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
648         d.addCallback(_generated)
649         return d
650
651     # TODO: when we publish to 20 peers, we should get one share per peer on 10
652     # when we publish to 3 peers, we should get either 3 or 4 shares per peer
653     # when we publish to zero peers, we should get a NotEnoughSharesError
654
class PublishMixin:
    # Helpers that publish mutable files onto a FakeClient grid and let a
    # test snapshot and later restore the raw shares held by the fake
    # storage (self._storage._peers, a dict of peerid -> {shnum: data}).

    def publish_one(self):
        # Publish a single file and keep its shares around, so a test can
        # manipulate them later. self._fn is the node that created the file;
        # self._fn2 is a second, not-yet-used node built from the same URI.
        self.CONTENTS = "New contents go here" * 1000
        peer_count = 20
        self._client = FakeClient(peer_count)
        self._storage = self._client._storage
        def _remember_nodes(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d = self._client.create_mutable_file(self.CONTENTS)
        d.addCallback(_remember_nodes)
        return d

    def publish_multiple(self):
        # Publish several successive versions of one file, snapshotting the
        # shares after each step into self._copied_shares[0..4]:
        #   0 -> seqnum 1, 1 -> s2, 2 -> s3, 3 -> s4a,
        #   4 -> s4b (published on top of a rollback to s3).
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        peer_count = 20
        self._client = FakeClient(peer_count)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _build_versions(node):
            self._fn = node
            def _overwrite(ignored, which):
                return node.overwrite(self.CONTENTS[which])
            chain = defer.succeed(None)
            chain.addCallback(self._copy_shares, 0)
            # s2, s3, s4a: overwrite, then snapshot the resulting shares
            for which in (1, 2, 3):
                chain.addCallback(_overwrite, which)
                chain.addCallback(self._copy_shares, which)
            # put every share back at version s3, then upload again to get
            # the competing s4b
            rollback = dict.fromkeys(range(10), 2)
            chain.addCallback(lambda ignored: self._set_versions(rollback))
            chain.addCallback(_overwrite, 4) # s4b
            chain.addCallback(self._copy_shares, 4)
            # storage is left holding state 4
            return chain
        d.addCallback(_build_versions)
        return d

    def _copy_shares(self, ignored, index):
        # Snapshot the current shares under self._copied_shares[index]. Both
        # the outer (per-peer) dict and each inner (per-shnum) dict are
        # copied, so later writes into storage don't leak into the snapshot.
        live = self._client._storage._peers
        snapshot = {}
        for peerid in live:
            snapshot[peerid] = dict(live[peerid])
        self._copied_shares[index] = snapshot

    def _set_versions(self, versionmap):
        # versionmap maps shnum -> snapshot index (0..4). Every share whose
        # shnum appears in the map is replaced by that snapshot's copy of it;
        # shares whose shnum is absent keep their current contents.
        live = self._client._storage._peers
        snapshots = self._copied_shares
        for peerid, peer_shares in live.items():
            for shnum in peer_shares:
                if shnum not in versionmap:
                    continue
                peer_shares[shnum] = snapshots[versionmap[shnum]][peerid][shnum]
724
725
class Servermap(unittest.TestCase, PublishMixin):
    # Exercise ServermapUpdater against a single freshly-published file
    # (PublishMixin.publish_one). The share-count assertions below expect the
    # file to have been encoded with k=3, N=10.

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        """Build a brand-new ServerMap for `fn` (default: self._fn).

        Returns the Deferred from ServermapUpdater.update(); the tests below
        treat its result as the populated ServerMap.
        """
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        """Run another ServermapUpdater pass for self._fn, starting from
        `oldmap`. Returns the Deferred from ServermapUpdater.update().
        """
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        """Assert that `sm` knows of exactly one version, that it is
        recoverable, and that `num_shares` of its (k=3, N=10) shares were
        found. Returns `sm` so it can be chained in a callback sequence.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should no longer contain the marked shares, and new
            # shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        """Assert that `sm` found no versions at all (recoverable or not).
        Returns `sm`, consistent with the other failUnless* helpers here.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)
        return sm

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        """Assert that `sm` sees exactly one version, that it is
        unrecoverable, and that only 2 of its (k=3, N=10) shares were found.
        Returns `sm` so it can be chained.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        # Empty out peers until only two entries of s._peers remain
        # untouched. This counts peers, not shares; the check below confirms
        # that exactly two share-holding peers are left.
        num_peers = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_peers -= 1
            if num_peers == 2:
                break
        # now there ought to be only two shares left. Use a trial assertion
        # rather than a bare 'assert', which 'python -O' would strip out.
        remaining = [peerid for peerid in s._peers if s._peers[peerid]]
        self.failUnlessEqual(len(remaining), 2)

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d
902
903
904
905 class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
906     def setUp(self):
907         return self.publish_one()
908
909     def make_servermap(self, mode=MODE_READ, oldmap=None):
910         if oldmap is None:
911             oldmap = ServerMap()
912         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
913         d = smu.update()
914         return d
915
916     def abbrev_verinfo(self, verinfo):
917         if verinfo is None:
918             return None
919         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
920          offsets_tuple) = verinfo
921         return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])
922
923     def abbrev_verinfo_dict(self, verinfo_d):
924         output = {}
925         for verinfo,value in verinfo_d.items():
926             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
927              offsets_tuple) = verinfo
928             output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
929         return output
930
931     def dump_servermap(self, servermap):
932         print "SERVERMAP", servermap
933         print "RECOVERABLE", [self.abbrev_verinfo(v)
934                               for v in servermap.recoverable_versions()]
935         print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
936         print "available", self.abbrev_verinfo_dict(servermap.shares_available())
937
938     def do_download(self, servermap, version=None):
939         if version is None:
940             version = servermap.best_recoverable_version()
941         r = Retrieve(self._fn, servermap, version)
942         return r.download()
943
944     def test_basic(self):
945         d = self.make_servermap()
946         def _do_retrieve(servermap):
947             self._smap = servermap
948             #self.dump_servermap(servermap)
949             self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
950             return self.do_download(servermap)
951         d.addCallback(_do_retrieve)
952         def _retrieved(new_contents):
953             self.failUnlessEqual(new_contents, self.CONTENTS)
954         d.addCallback(_retrieved)
955         # we should be able to re-use the same servermap, both with and
956         # without updating it.
957         d.addCallback(lambda res: self.do_download(self._smap))
958         d.addCallback(_retrieved)
959         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
960         d.addCallback(lambda res: self.do_download(self._smap))
961         d.addCallback(_retrieved)
962         # clobbering the pubkey should make the servermap updater re-fetch it
963         def _clobber_pubkey(res):
964             self._fn._pubkey = None
965         d.addCallback(_clobber_pubkey)
966         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
967         d.addCallback(lambda res: self.do_download(self._smap))
968         d.addCallback(_retrieved)
969         return d
970
971     def test_all_shares_vanished(self):
972         d = self.make_servermap()
973         def _remove_shares(servermap):
974             for shares in self._storage._peers.values():
975                 shares.clear()
976             d1 = self.shouldFail(NotEnoughSharesError,
977                                  "test_all_shares_vanished",
978                                  "ran out of peers",
979                                  self.do_download, servermap)
980             return d1
981         d.addCallback(_remove_shares)
982         return d
983
984     def test_no_servers(self):
985         c2 = FakeClient(0)
986         self._fn._client = c2
987         # if there are no servers, then a MODE_READ servermap should come
988         # back empty
989         d = self.make_servermap()
990         def _check_servermap(servermap):
991             self.failUnlessEqual(servermap.best_recoverable_version(), None)
992             self.failIf(servermap.recoverable_versions())
993             self.failIf(servermap.unrecoverable_versions())
994             self.failIf(servermap.all_peers())
995         d.addCallback(_check_servermap)
996         return d
997     test_no_servers.timeout = 15
998
999     def test_no_servers_download(self):
1000         c2 = FakeClient(0)
1001         self._fn._client = c2
1002         d = self.shouldFail(UnrecoverableFileError,
1003                             "test_no_servers_download",
1004                             "no recoverable versions",
1005                             self._fn.download_best_version)
1006         def _restore(res):
1007             # a failed download that occurs while we aren't connected to
1008             # anybody should not prevent a subsequent download from working.
1009             # This isn't quite the webapi-driven test that #463 wants, but it
1010             # should be close enough.
1011             self._fn._client = self._client
1012             return self._fn.download_best_version()
1013         def _retrieved(new_contents):
1014             self.failUnlessEqual(new_contents, self.CONTENTS)
1015         d.addCallback(_restore)
1016         d.addCallback(_retrieved)
1017         return d
1018     test_no_servers_download.timeout = 15
1019
1020     def _test_corrupt_all(self, offset, substring,
1021                           should_succeed=False, corrupt_early=True,
1022                           failure_checker=None):
1023         d = defer.succeed(None)
1024         if corrupt_early:
1025             d.addCallback(corrupt, self._storage, offset)
1026         d.addCallback(lambda res: self.make_servermap())
1027         if not corrupt_early:
1028             d.addCallback(corrupt, self._storage, offset)
1029         def _do_retrieve(servermap):
1030             ver = servermap.best_recoverable_version()
1031             if ver is None and not should_succeed:
1032                 # no recoverable versions == not succeeding. The problem
1033                 # should be noted in the servermap's list of problems.
1034                 if substring:
1035                     allproblems = [str(f) for f in servermap.problems]
1036                     self.failUnless(substring in "".join(allproblems))
1037                 return servermap
1038             if should_succeed:
1039                 d1 = self._fn.download_version(servermap, ver)
1040                 d1.addCallback(lambda new_contents:
1041                                self.failUnlessEqual(new_contents, self.CONTENTS))
1042             else:
1043                 d1 = self.shouldFail(NotEnoughSharesError,
1044                                      "_corrupt_all(offset=%s)" % (offset,),
1045                                      substring,
1046                                      self._fn.download_version, servermap, ver)
1047             if failure_checker:
1048                 d1.addCallback(failure_checker)
1049             d1.addCallback(lambda res: servermap)
1050             return d1
1051         d.addCallback(_do_retrieve)
1052         return d
1053
1054     def test_corrupt_all_verbyte(self):
1055         # when the version byte is not 0, we hit an assertion error in
1056         # unpack_share().
1057         d = self._test_corrupt_all(0, "AssertionError")
1058         def _check_servermap(servermap):
1059             # and the dump should mention the problems
1060             s = StringIO()
1061             dump = servermap.dump(s).getvalue()
1062             self.failUnless("10 PROBLEMS" in dump, dump)
1063         d.addCallback(_check_servermap)
1064         return d
1065
1066     def test_corrupt_all_seqnum(self):
1067         # a corrupt sequence number will trigger a bad signature
1068         return self._test_corrupt_all(1, "signature is invalid")
1069
1070     def test_corrupt_all_R(self):
1071         # a corrupt root hash will trigger a bad signature
1072         return self._test_corrupt_all(9, "signature is invalid")
1073
1074     def test_corrupt_all_IV(self):
1075         # a corrupt salt/IV will trigger a bad signature
1076         return self._test_corrupt_all(41, "signature is invalid")
1077
1078     def test_corrupt_all_k(self):
1079         # a corrupt 'k' will trigger a bad signature
1080         return self._test_corrupt_all(57, "signature is invalid")
1081
1082     def test_corrupt_all_N(self):
1083         # a corrupt 'N' will trigger a bad signature
1084         return self._test_corrupt_all(58, "signature is invalid")
1085
1086     def test_corrupt_all_segsize(self):
1087         # a corrupt segsize will trigger a bad signature
1088         return self._test_corrupt_all(59, "signature is invalid")
1089
1090     def test_corrupt_all_datalen(self):
1091         # a corrupt data length will trigger a bad signature
1092         return self._test_corrupt_all(67, "signature is invalid")
1093
1094     def test_corrupt_all_pubkey(self):
1095         # a corrupt pubkey won't match the URI's fingerprint. We need to
1096         # remove the pubkey from the filenode, or else it won't bother trying
1097         # to update it.
1098         self._fn._pubkey = None
1099         return self._test_corrupt_all("pubkey",
1100                                       "pubkey doesn't match fingerprint")
1101
1102     def test_corrupt_all_sig(self):
1103         # a corrupt signature is a bad one
1104         # the signature runs from about [543:799], depending upon the length
1105         # of the pubkey
1106         return self._test_corrupt_all("signature", "signature is invalid")
1107
1108     def test_corrupt_all_share_hash_chain_number(self):
1109         # a corrupt share hash chain entry will show up as a bad hash. If we
1110         # mangle the first byte, that will look like a bad hash number,
1111         # causing an IndexError
1112         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1113
1114     def test_corrupt_all_share_hash_chain_hash(self):
1115         # a corrupt share hash chain entry will show up as a bad hash. If we
1116         # mangle a few bytes in, that will look like a bad hash.
1117         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1118
1119     def test_corrupt_all_block_hash_tree(self):
1120         return self._test_corrupt_all("block_hash_tree",
1121                                       "block hash tree failure")
1122
1123     def test_corrupt_all_block(self):
1124         return self._test_corrupt_all("share_data", "block hash tree failure")
1125
1126     def test_corrupt_all_encprivkey(self):
1127         # a corrupted privkey won't even be noticed by the reader, only by a
1128         # writer.
1129         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1130
1131
1132     def test_corrupt_all_seqnum_late(self):
1133         # corrupting the seqnum between mapupdate and retrieve should result
1134         # in NotEnoughSharesError, since each share will look invalid
1135         def _check(res):
1136             f = res[0]
1137             self.failUnless(f.check(NotEnoughSharesError))
1138             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1139         return self._test_corrupt_all(1, "ran out of peers",
1140                                       corrupt_early=False,
1141                                       failure_checker=_check)
1142
1143     def test_corrupt_all_block_hash_tree_late(self):
1144         def _check(res):
1145             f = res[0]
1146             self.failUnless(f.check(NotEnoughSharesError))
1147         return self._test_corrupt_all("block_hash_tree",
1148                                       "block hash tree failure",
1149                                       corrupt_early=False,
1150                                       failure_checker=_check)
1151
1152
1153     def test_corrupt_all_block_late(self):
1154         def _check(res):
1155             f = res[0]
1156             self.failUnless(f.check(NotEnoughSharesError))
1157         return self._test_corrupt_all("share_data", "block hash tree failure",
1158                                       corrupt_early=False,
1159                                       failure_checker=_check)
1160
1161
1162     def test_basic_pubkey_at_end(self):
1163         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1164         # download to succeed but forcing a bunch of retries first. Note that
1165         # this is rather pessimistic: our Retrieve process will throw away
1166         # the whole share if the pubkey is bad, even though the rest of the
1167         # share might be good.
1168
1169         self._fn._pubkey = None
1170         k = self._fn.get_required_shares()
1171         N = self._fn.get_total_shares()
1172         d = defer.succeed(None)
1173         d.addCallback(corrupt, self._storage, "pubkey",
1174                       shnums_to_corrupt=range(0, N-k))
1175         d.addCallback(lambda res: self.make_servermap())
1176         def _do_retrieve(servermap):
1177             self.failUnless(servermap.problems)
1178             self.failUnless("pubkey doesn't match fingerprint"
1179                             in str(servermap.problems[0]))
1180             ver = servermap.best_recoverable_version()
1181             r = Retrieve(self._fn, servermap, ver)
1182             return r.download()
1183         d.addCallback(_do_retrieve)
1184         d.addCallback(lambda new_contents:
1185                       self.failUnlessEqual(new_contents, self.CONTENTS))
1186         return d
1187
1188     def test_corrupt_some(self):
1189         # corrupt the data of first five shares (so the servermap thinks
1190         # they're good but retrieve marks them as bad), so that the
1191         # MODE_READ set of 6 will be insufficient, forcing node.download to
1192         # retry with more servers.
1193         corrupt(None, self._storage, "share_data", range(5))
1194         d = self.make_servermap()
1195         def _do_retrieve(servermap):
1196             ver = servermap.best_recoverable_version()
1197             self.failUnless(ver)
1198             return self._fn.download_best_version()
1199         d.addCallback(_do_retrieve)
1200         d.addCallback(lambda new_contents:
1201                       self.failUnlessEqual(new_contents, self.CONTENTS))
1202         return d
1203
1204     def test_download_fails(self):
1205         corrupt(None, self._storage, "signature")
1206         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1207                             "no recoverable versions",
1208                             self._fn.download_best_version)
1209         return d
1210
1211
class CheckerMixin:
    # Assertion helpers for check()/verify results. check_good and check_bad
    # hand their input back, so they can sit in a Deferred callback chain.

    def check_good(self, r, where):
        """Fail (labelled with `where`) unless `r` reports a healthy file.

        Returns `r` unchanged.
        """
        healthy = r.is_healthy()
        self.failUnless(healthy, where)
        return r

    def check_bad(self, r, where):
        """Fail (labelled with `where`) if `r` reports a healthy file.

        Returns `r` unchanged.
        """
        healthy = r.is_healthy()
        self.failIf(healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        """Assert that r.problems records an `expected_exception` whose
        string form contains `substring`.

        r.problems is walked as (peerid, storage_index, shnum, failure)
        tuples. Only the first failure matching `expected_exception` is
        inspected. If none matches, the test fails.
        """
        for (_peerid, _si, _shnum, problem) in r.problems:
            if not problem.check(expected_exception):
                continue
            text = str(problem)
            self.failUnless(substring in text,
                            "%s: substring '%s' not in '%s'" %
                            (where, substring, text))
            break
        else:
            self.fail("%s: didn't see expected exception %s in problems %s" %
                      (where, expected_exception, r.problems))
1230
1231
class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    # Run MutableFileNode.check(), both the plain checker and verify=True,
    # against a freshly published file whose shares in the fake storage
    # have been deleted or corrupted.

    def setUp(self):
        return self.publish_one()


    def test_check_good(self):
        checking = self._fn.check(Monitor())
        checking.addCallback(self.check_good, "test_check_good")
        return checking

    def test_check_no_shares(self):
        for peer_shares in self._storage._peers.values():
            peer_shares.clear()
        checking = self._fn.check(Monitor())
        checking.addCallback(self.check_bad, "test_check_no_shares")
        return checking

    def test_check_not_enough_shares(self):
        # keep only share 0 on every peer
        for peer_shares in self._storage._peers.values():
            doomed = [shnum for shnum in peer_shares.keys() if shnum > 0]
            for shnum in doomed:
                del peer_shares[shnum]
        checking = self._fn.check(Monitor())
        checking.addCallback(self.check_bad, "test_check_not_enough_shares")
        return checking

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        checking = self._fn.check(Monitor())
        checking.addCallback(self.check_bad, "test_check_all_bad_sig")
        return checking

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this.. it doesn't look at actual data
        checking = self._fn.check(Monitor())
        checking.addCallback(self.check_good, "test_check_all_bad_blocks")
        return checking

    def test_verify_good(self):
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_good, "test_verify_good")
        return verifying

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_bad, "test_verify_all_bad_sig")
        return verifying

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_bad, "test_verify_one_bad_sig")
        return verifying

    def test_verify_one_bad_block(self):
        where = "test_verify_one_bad_block"
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_bad, where)
        verifying.addCallback(self.check_expected_failure,
                              CorruptShareError, "block hash tree failure",
                              where)
        return verifying

    def test_verify_one_bad_sharehash(self):
        where = "test_verify_one_bad_sharehash"
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_bad, where)
        verifying.addCallback(self.check_expected_failure,
                              CorruptShareError, "corrupt hashes",
                              where)
        return verifying

    def test_verify_one_bad_encprivkey(self):
        where = "test_verify_one_bad_encprivkey"
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        verifying = self._fn.check(Monitor(), verify=True)
        verifying.addCallback(self.check_bad, where)
        verifying.addCallback(self.check_expected_failure,
                              CorruptShareError, "invalid privkey",
                              where)
        return verifying

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        ro_node = self._fn.get_readonly()
        # a read-only node has no way to validate the privkey
        verifying = ro_node.check(Monitor(), verify=True)
        verifying.addCallback(self.check_good,
                              "test_verify_one_bad_encprivkey_uncheckable")
        return verifying
1324
class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
    """Tests for MutableFileNode.repair().

    Tests record snapshots of the grid's shares in self.old_shares (via
    copy_shares) so that later steps can compare the grid before and after
    a repair attempt.
    """

    def get_shares(self, s):
        """Return a new dict mapping (peerid, shnum) to share data for every
        share held in the fake storage 's' (read from s._peers)."""
        all_shares = {}
        for peerid, shares in s._peers.items():
            for shnum, data in shares.items():
                all_shares[(peerid, shnum)] = data
        return all_shares

    def copy_shares(self, ignored=None):
        """Append a snapshot of self._storage's shares to self.old_shares.

        The argument is ignored, so this can be used directly as a Deferred
        callback."""
        self.old_shares.append(self.get_shares(self._storage))

    def test_repair_nop(self):
        # repairing a freshly-published file should return IRepairResults and
        # leave every share in place, re-published at the next seqnum
        self.old_shares = []
        d = self.publish_one()
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        def _check_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine results

            self.copy_shares()

            initial_shares = self.old_shares[0]
            new_shares = self.old_shares[1]
            # TODO: this really shouldn't change anything. When we implement
            # a "minimal-bandwidth" repairer, change this test to assert:
            #self.failUnlessEqual(new_shares, initial_shares)

            # all shares should be in the same place as before
            self.failUnlessEqual(set(initial_shares.keys()),
                                 set(new_shares.keys()))
            # but they should all be at a newer seqnum. The IV will be
            # different, so the roothash will be too.
            for key in initial_shares:
                (version0, seqnum0, root_hash0, IV0,
                 k0, N0, segsize0, datalen0, o0) = \
                    unpack_header(initial_shares[key])
                (version1, seqnum1, root_hash1, IV1,
                 k1, N1, segsize1, datalen1, o1) = \
                    unpack_header(new_shares[key])
                self.failUnlessEqual(version0, version1)
                self.failUnlessEqual(seqnum0+1, seqnum1)
                self.failUnlessEqual(k0, k1)
                self.failUnlessEqual(N0, N1)
                self.failUnlessEqual(segsize0, segsize1)
                self.failUnlessEqual(datalen0, datalen1)
        d.addCallback(_check_results)
        return d

    def failIfSharesChanged(self, ignored=None):
        """Assert that the two most recent share snapshots are identical.

        The argument is ignored so this can be used as a Deferred callback."""
        old_shares = self.old_shares[-2]
        current_shares = self.old_shares[-1]
        self.failUnlessEqual(old_shares, current_shares)

    def test_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair will refuse to merge multiple highest seqnums unless you
        # pass force=True
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        def _try_repair(check_results):
            ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
            d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
                                 self._fn.repair, check_results)
            # the refused repair must not have touched any share
            d2.addCallback(self.copy_shares)
            d2.addCallback(self.failIfSharesChanged)
            d2.addCallback(lambda res: check_results)
            return d2
        d.addCallback(_try_repair)
        d.addCallback(lambda check_results:
                      self._fn.repair(check_results, force=True))
        # this should give us 10 shares of the highest roothash
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # now, which should have won? the version with the larger
            # roothash is expected to be the one that gets republished
            roothash_s4a = self.get_roothash_for(3)
            roothash_s4b = self.get_roothash_for(4)
            if roothash_s4b > roothash_s4a:
                expected_contents = self.CONTENTS[4]
            else:
                expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def test_non_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair should not refuse a repair that doesn't need to merge. In
        # this case, we combine v2 with v3. The repair should ignore v2 and
        # copy v3 into a new v5.
        d.addCallback(lambda res:
                      self._set_versions({0:2,2:2,4:2,6:2,8:2,
                                          1:3,3:3,5:3,7:3,9:3}))
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        # this should give us 10 shares of v3
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # v3 has the highest seqnum, so its contents should have been
            # carried forward into the new version
            expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def get_roothash_for(self, index):
        """Return the root hash of the first share found in the saved share
        set self._copied_shares[index] (a peerid -> {shnum: share} dict), or
        None if that set holds no shares."""
        shares = self._copied_shares[index]
        for peerid, peer_shares in shares.items():
            for shnum, share in peer_shares.items():
                (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
                          unpack_header(share)
                return root_hash
        return None
1472
class MultipleEncodings(unittest.TestCase):
    """Download must survive a grid holding shares of one mutable file that
    were encoded with several different (k, N) parameters."""

    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        """Publish 'data' as a k-of-n encoding of self._fn's file.

        Returns a Deferred that fires with the resulting peerid ->
        {shnum: share} dict. The fake storage is emptied before publishing
        and again afterwards, so the grid is left empty."""
        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        """Run a ServermapUpdater for self._fn in 'mode', starting from
        'oldmap' (a fresh ServerMap when None), and return its update()
        Deferred."""
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we publish each encoding in turn, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2, or
            # set3 according to the following sequence:
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # so that neither form can be recovered until fetch [g], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will only become recoverable at fetch [h].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
            share_sets = {1: self._shares1,
                          2: self._shares2,
                          3: self._shares3}

            sharemap = {}
            sb = self._client.get_storage_broker()
            storage_peers = self._client._storage._peers

            for peerid in sb.get_all_serverids():
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    # look up (or create) this peer's bucket instead of
                    # replacing it, so a peer that is given two shares keeps
                    # both. _encode left the storage empty, so every bucket
                    # starts out empty.
                    peers = storage_peers.setdefault(peerid, {})
                    source = share_sets.get(which, {}).get(peerid, {})
                    if shnum in source:
                        peers[shnum] = source[shnum]
                        sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._client._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
1615
1616
class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
    """Download, servermap, checker, and modify() behavior when the grid
    holds shares from several versions of one mutable file."""

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        def _expect(index):
            # callback factory: the downloaded data must equal CONTENTS[index]
            return lambda res: self.failUnlessEqual(res, self.CONTENTS[index])
        def _download(ignored):
            return self._fn.download_best_version()

        # with a mix of versions in the grid, download_best_version should
        # get the latest one
        self._set_versions(dict.fromkeys((0, 2, 4, 6, 8), 2))
        d = self._fn.download_best_version()
        d.addCallback(_expect(4))
        # ... and the checker should report problems
        d.addCallback(lambda ignored: self._fn.check(Monitor()))
        d.addCallback(self.check_bad, "test_multiple_versions")

        # with every share at version 2, that is what we should download
        d.addCallback(lambda ignored:
                      self._set_versions(dict.fromkeys(range(10), 2)))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # a single share at version 3 is not enough: we should still get v2
        d.addCallback(lambda ignored: self._set_versions({0: 3}))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # but the servermap should see the unrecoverable newer version. This
        # depends upon the single newer share being queried early.
        d.addCallback(lambda ignored: self._fn.get_servermap(MODE_READ))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = newer.items()[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # with two parallel versions (s4a and s4b) each on half the shares,
        # either one could be recovered
        d.addCallback(lambda ignored:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(lambda ignored: self._fn.get_servermap(MODE_READ))
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(_download)
        d.addCallback(lambda res:
                      self.failUnless(res in (self.CONTENTS[3],
                                              self.CONTENTS[4])))
        return d

    def test_replace(self):
        # modify() on a grid holding a mix of versions should replace all of
        # them with a single newer version.
        #
        # One share is at version index 3 (seqnum 4) and the rest are at
        # index 2 (seqnum 3). The lone newer share is not recoverable, so
        # modify() should start from v2 and publish seqnum 5. Note that the
        # index we give to _set_versions is different than the seqnum.
        target = dict.fromkeys(range(10), 2) # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        def _modify(oldversion, servermap, first_time):
            return oldversion + " modified"
        d = self._fn.modify(_modify)
        d.addCallback(lambda ignored: self._fn.download_best_version())
        expected = self.CONTENTS[2] + " modified"
        d.addCallback(lambda res: self.failUnlessEqual(res, expected))
        # the servermap should show that the outlier was replaced too
        d.addCallback(lambda ignored: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
        d.addCallback(_check_smap)
        return d
1696
1697
class Utils(unittest.TestCase):
    """Tests for ResponseCache: its private range predicates (_inside,
    _does_overlap) and its add()/read() behavior."""

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # brute-force reference: is every offset of range x also in range y?
        xs = set(range(x_start, x_start + x_length))
        ys = set(range(y_start, y_start + y_length))
        expected = xs.issubset(ys)
        actual = c._inside(x_start, x_length, y_start, y_length)
        self.failUnlessEqual(expected, actual,
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_inside(self):
        c = ResponseCache()
        x_start, x_length = 10, 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(c, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # brute-force reference: do ranges x and y share any offset?
        xs = set(range(x_start, x_start + x_length))
        ys = set(range(y_start, y_start + y_length))
        expected = len(xs & ys) > 0
        actual = c._does_overlap(x_start, x_length, y_start, y_length)
        self.failUnlessEqual(expected, actual,
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_overlap(self):
        c = ResponseCache()
        x_start, x_length = 10, 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(c, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        # each entry: (read() arguments, expected (data, timestamp) result).
        # Only reads lying entirely inside one cached fragment are hits.
        expectations = [
            (("v2", 1, 10, 11), nope),
            (("v1", 2, 10, 11), nope),
            (("v1", 1, 0, 10), (xdata[:10], "time0")),
            (("v1", 1, 90, 10), (xdata[90:], "time0")),
            (("v1", 1, 300, 10), nope),
            (("v1", 1, 2050, 5), (ydata[50:55], "time1")),
            (("v1", 1, 0, 101), nope),
            (("v1", 1, 99, 1), (xdata[99:100], "time0")),
            (("v1", 1, 100, 1), nope),
            (("v1", 1, 1990, 9), nope),
            (("v1", 1, 1990, 10), nope),
            (("v1", 1, 1990, 11), nope),
            (("v1", 1, 1990, 15), nope),
            (("v1", 1, 1990, 19), nope),
            (("v1", 1, 1990, 20), nope),
            (("v1", 1, 1990, 21), nope),
            (("v1", 1, 1990, 25), nope),
            (("v1", 1, 1999, 25), nope),
            ]
        for read_args, expected in expectations:
            self.failUnlessEqual(c.read(*read_args), expected)

        # optional: join fragments. Adjacent fragments are added below, but
        # reading across them is presumably not supported yet, so the
        # assertion stays disabled.
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1765
class Exceptions(unittest.TestCase):
    def test_repr(self):
        """repr() of each mutable-file exception names its class."""
        cases = [
            ("NeedMoreDataError", lambda: NeedMoreDataError(100, 50, 100)),
            ("UncoordinatedWriteError", UncoordinatedWriteError),
            ]
        for class_name, make_exception in cases:
            text = repr(make_exception())
            self.failUnless(class_name in text, text)
1772
# we can't do these tests with a FakeClient, since it uses FakeStorageServer
# instances which always succeed. So we need a less-fake one.
1775
class IntentionalError(Exception):
    """Raised by a LocalWrapper whose 'broken' flag is set, to simulate a
    failing storage server."""
1778
class LocalWrapper:
    """Stand-in for a remote reference to a local object.

    callRemote("foo", ...) invokes original.remote_foo(...) from a Deferred
    produced by fireEventually(), and returns that Deferred. While 'broken'
    is true, calls fail with IntentionalError instead. If a
    post_call_notifier is set, it is chained after each call as
    notifier(result, methname).
    """

    def __init__(self, original):
        self.original = original
        self.broken = False
        self.post_call_notifier = None

    def callRemote(self, methname, *args, **kwargs):
        def _invoke(ignored):
            # 'broken' is consulted when the call actually runs, not when
            # it was requested
            if self.broken:
                raise IntentionalError("I was asked to break")
            method = getattr(self.original, "remote_" + methname)
            return method(*args, **kwargs)
        result = fireEventually()
        result.addCallback(_invoke)
        if self.post_call_notifier:
            result.addCallback(self.post_call_notifier, methname)
        return result
1795
class LessFakeClient(FakeClient):
    """A FakeClient backed by real StorageServer instances.

    Each of the 'num_peers' servers keeps its data in its own directory
    under 'basedir'. Each server is reached through a LocalWrapper, so tests
    can break individual servers by setting its 'broken' flag.
    """

    def __init__(self, basedir, num_peers=10):
        self._num_peers = num_peers
        # deterministic 20-byte server ids derived from the server index
        server_ids = [tagged_hash("peerid", "%d" % index)[:20]
                      for index in range(self._num_peers)]
        broker = StorageFarmBroker()
        self.storage_broker = broker
        for server_id in server_ids:
            server_dir = os.path.join(basedir,
                                      idlib.shortnodeid_b2a(server_id))
            make_dirs(server_dir)
            wrapper = LocalWrapper(StorageServer(server_dir, server_id))
            broker.add_server(server_id, wrapper)
        self.nodeid = "fakenodeid"
1810
1811
1812 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1813     def test_publish_surprise(self):
1814         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1815         self.client = LessFakeClient(basedir)
1816         d = self.client.create_mutable_file("contents 1")
1817         def _created(n):
1818             d = defer.succeed(None)
1819             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1820             def _got_smap1(smap):
1821                 # stash the old state of the file
1822                 self.old_map = smap
1823             d.addCallback(_got_smap1)
1824             # then modify the file, leaving the old map untouched
1825             d.addCallback(lambda res: log.msg("starting winning write"))
1826             d.addCallback(lambda res: n.overwrite("contents 2"))
1827             # now attempt to modify the file with the old servermap. This
1828             # will look just like an uncoordinated write, in which every
1829             # single share got updated between our mapupdate and our publish
1830             d.addCallback(lambda res: log.msg("starting doomed write"))
1831             d.addCallback(lambda res:
1832                           self.shouldFail(UncoordinatedWriteError,
1833                                           "test_publish_surprise", None,
1834                                           n.upload,
1835                                           "contents 2a", self.old_map))
1836             return d
1837         d.addCallback(_created)
1838         return d
1839
1840     def test_retrieve_surprise(self):
1841         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1842         self.client = LessFakeClient(basedir)
1843         d = self.client.create_mutable_file("contents 1")
1844         def _created(n):
1845             d = defer.succeed(None)
1846             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1847             def _got_smap1(smap):
1848                 # stash the old state of the file
1849                 self.old_map = smap
1850             d.addCallback(_got_smap1)
1851             # then modify the file, leaving the old map untouched
1852             d.addCallback(lambda res: log.msg("starting winning write"))
1853             d.addCallback(lambda res: n.overwrite("contents 2"))
1854             # now attempt to retrieve the old version with the old servermap.
1855             # This will look like someone has changed the file since we
1856             # updated the servermap.
1857             d.addCallback(lambda res: n._cache._clear())
1858             d.addCallback(lambda res: log.msg("starting doomed read"))
1859             d.addCallback(lambda res:
1860                           self.shouldFail(NotEnoughSharesError,
1861                                           "test_retrieve_surprise",
1862                                           "ran out of peers: have 0 shares (k=3)",
1863                                           n.download_version,
1864                                           self.old_map,
1865                                           self.old_map.best_recoverable_version(),
1866                                           ))
1867             return d
1868         d.addCallback(_created)
1869         return d
1870
1871     def test_unexpected_shares(self):
1872         # upload the file, take a servermap, shut down one of the servers,
1873         # upload it again (causing shares to appear on a new server), then
1874         # upload using the old servermap. The last upload should fail with an
1875         # UncoordinatedWriteError, because of the shares that didn't appear
1876         # in the servermap.
1877         basedir = os.path.join("mutable/CollidingWrites/test_unexpexted_shares")
1878         self.client = LessFakeClient(basedir)
1879         d = self.client.create_mutable_file("contents 1")
1880         def _created(n):
1881             d = defer.succeed(None)
1882             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1883             def _got_smap1(smap):
1884                 # stash the old state of the file
1885                 self.old_map = smap
1886                 # now shut down one of the servers
1887                 peer0 = list(smap.make_sharemap()[0])[0]
1888                 self.client.debug_remove_connection(peer0)
1889                 # then modify the file, leaving the old map untouched
1890                 log.msg("starting winning write")
1891                 return n.overwrite("contents 2")
1892             d.addCallback(_got_smap1)
1893             # now attempt to modify the file with the old servermap. This
1894             # will look just like an uncoordinated write, in which every
1895             # single share got updated between our mapupdate and our publish
1896             d.addCallback(lambda res: log.msg("starting doomed write"))
1897             d.addCallback(lambda res:
1898                           self.shouldFail(UncoordinatedWriteError,
1899                                           "test_surprise", None,
1900                                           n.upload,
1901                                           "contents 2a", self.old_map))
1902             return d
1903         d.addCallback(_created)
1904         return d
1905
1906     def test_bad_server(self):
1907         # Break one server, then create the file: the initial publish should
1908         # complete with an alternate server. Breaking a second server should
1909         # not prevent an update from succeeding either.
1910         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1911         self.client = LessFakeClient(basedir, 20)
1912         # to make sure that one of the initial peers is broken, we have to
1913         # get creative. We create the keys, so we can figure out the storage
1914         # index, but we hold off on doing the initial publish until we've
1915         # broken the server on which the first share wants to be stored.
1916         n = FastMutableFileNode(self.client)
1917         d = defer.succeed(None)
1918         d.addCallback(n._generate_pubprivkeys)
1919         d.addCallback(n._generated)
1920         def _break_peer0(res):
1921             si = n.get_storage_index()
1922             peerlist = list(self.client.storage_broker.get_servers(si))
1923             peerid0, connection0 = peerlist[0]
1924             peerid1, connection1 = peerlist[1]
1925             connection0.broken = True
1926             self.connection1 = connection1
1927         d.addCallback(_break_peer0)
1928         # now let the initial publish finally happen
1929         d.addCallback(lambda res: n._upload("contents 1", None))
1930         # that ought to work
1931         d.addCallback(lambda res: n.download_best_version())
1932         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1933         # now break the second peer
1934         def _break_peer1(res):
1935             self.connection1.broken = True
1936         d.addCallback(_break_peer1)
1937         d.addCallback(lambda res: n.overwrite("contents 2"))
1938         # that ought to work too
1939         d.addCallback(lambda res: n.download_best_version())
1940         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1941         def _explain_error(f):
1942             print f
1943             if f.check(NotEnoughServersError):
1944                 print "first_error:", f.value.first_error
1945             return f
1946         d.addErrback(_explain_error)
1947         return d
1948
1949     def test_bad_server_overlap(self):
1950         # like test_bad_server, but with no extra unused servers to fall back
1951         # upon. This means that we must re-use a server which we've already
1952         # used. If we don't remember the fact that we sent them one share
1953         # already, we'll mistakenly think we're experiencing an
1954         # UncoordinatedWriteError.
1955
1956         # Break one server, then create the file: the initial publish should
1957         # complete with an alternate server. Breaking a second server should
1958         # not prevent an update from succeeding either.
1959         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1960         self.client = LessFakeClient(basedir, 10)
1961         sb = self.client.get_storage_broker()
1962
1963         peerids = list(sb.get_all_serverids())
1964         self.client.debug_break_connection(peerids[0])
1965
1966         d = self.client.create_mutable_file("contents 1")
1967         def _created(n):
1968             d = n.download_best_version()
1969             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1970             # now break one of the remaining servers
1971             def _break_second_server(res):
1972                 self.client.debug_break_connection(peerids[1])
1973             d.addCallback(_break_second_server)
1974             d.addCallback(lambda res: n.overwrite("contents 2"))
1975             # that ought to work too
1976             d.addCallback(lambda res: n.download_best_version())
1977             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1978             return d
1979         d.addCallback(_created)
1980         return d
1981
1982     def test_publish_all_servers_bad(self):
1983         # Break all servers: the publish should fail
1984         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1985         self.client = LessFakeClient(basedir, 20)
1986         sb = self.client.get_storage_broker()
1987         for peerid in sb.get_all_serverids():
1988             self.client.debug_break_connection(peerid)
1989         d = self.shouldFail(NotEnoughServersError,
1990                             "test_publish_all_servers_bad",
1991                             "Ran out of non-bad servers",
1992                             self.client.create_mutable_file, "contents")
1993         return d
1994
1995     def test_publish_no_servers(self):
1996         # no servers at all: the publish should fail
1997         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1998         self.client = LessFakeClient(basedir, 0)
1999         d = self.shouldFail(NotEnoughServersError,
2000                             "test_publish_no_servers",
2001                             "Ran out of non-bad servers",
2002                             self.client.create_mutable_file, "contents")
2003         return d
2004     test_publish_no_servers.timeout = 30
2005
2006
    def test_privkey_query_error(self):
        # when a servermap is updated with MODE_WRITE, it tries to get the
        # privkey. Something might go wrong during this query attempt. This
        # test makes one peer's privkey query raise, and relies on the map
        # update still succeeding (any failure propagates through the
        # returned Deferred and fails the test).
        self.client = FakeClient(20)
        # we need some contents that are large enough to push the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d = self.client.create_mutable_file(LARGE)
        def _created(n):
            # keep the URI and a second, fresh node for the same file; self.n2
            # has not fetched the privkey, which the later map update needs
            self.uri = n.get_uri()
            self.n2 = self.client.create_node_from_uri(self.uri)
            # we start by doing a map update to figure out which is the first
            # server.
            return n.get_servermap(MODE_WRITE)
        d.addCallback(_created)
        # pass the servermap along on a later reactor turn (foolscap
        # eventual-send) rather than synchronously
        d.addCallback(lambda res: fireEventually(res))
        def _got_smap1(smap):
            # pick a peer that holds share 0 (make_sharemap() presumably maps
            # shnum -> set of peerids; TODO confirm)
            peer0 = list(smap.make_sharemap()[0])[0]
            # we tell FakeStorage to deliver this peer's response first, so
            # that it will be asked for the privkey first
            self.client._storage._sequence = [peer0]
            # now we make the peer answer its first query normally and fail
            # its second query
            self.client._storage._special_answers[peer0] = ["normal", "fail"]
        d.addCallback(_got_smap1)
        # now we update a servermap from a new node (which doesn't have the
        # privkey yet, forcing it to use a separate privkey query). Each
        # query response will trigger a privkey query, and since we're using
        # _sequence to make the peer0 response come back first, we'll send it
        # a privkey query first, and _sequence will again ensure that the
        # peer0 query will also come back before the others, and then
        # _special_answers will make sure that the query raises an exception.
        # The whole point of these hijinks is to exercise the code in
        # _privkey_query_failed. Note that the map-update will succeed, since
        # we'll just get a copy from one of the other shares.
        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
        # Using FakeStorage._sequence means there may be read requests still
        # outstanding when the map update finishes. Rather than waiting for
        # them, cancel FakeStorage's pending timer (if any) on both success
        # and failure, and pass the result through unchanged. (Presumably
        # this keeps trial from reporting a dirty reactor; confirm.)
        def _cancel_timer(res):
            if self.client._storage._pending_timer:
                self.client._storage._pending_timer.cancel()
            return res
        d.addBoth(_cancel_timer)
        return d
2050
2051     def test_privkey_query_missing(self):
2052         # like test_privkey_query_error, but the shares are deleted by the
2053         # second query, instead of raising an exception.
2054         self.client = FakeClient(20)
2055         LARGE = "These are Larger contents" * 200 # about 5KB
2056         d = self.client.create_mutable_file(LARGE)
2057         def _created(n):
2058             self.uri = n.get_uri()
2059             self.n2 = self.client.create_node_from_uri(self.uri)
2060             return n.get_servermap(MODE_WRITE)
2061         d.addCallback(_created)
2062         d.addCallback(lambda res: fireEventually(res))
2063         def _got_smap1(smap):
2064             peer0 = list(smap.make_sharemap()[0])[0]
2065             self.client._storage._sequence = [peer0]
2066             self.client._storage._special_answers[peer0] = ["normal", "none"]
2067         d.addCallback(_got_smap1)
2068         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2069         def _cancel_timer(res):
2070             if self.client._storage._pending_timer:
2071                 self.client._storage._pending_timer.cancel()
2072             return res
2073         d.addBoth(_cancel_timer)
2074         return d