
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

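# Illustrative sketch (not executed): a test can force a specific answer
# ordering by assigning a peer list to _sequence before reads arrive, e.g.:
#
#   s = FakeStorage()
#   s._sequence = list(s._peers.keys())   # release deferred reads in this order
#
# While _sequence is set, read() defers each answer; one second after the
# first query, _fire_readers releases them all in the given order.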

class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
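        # the real server's answer is a (testv_ok, read_data) pair; this fake
        # always reports success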
        answer = (True, readv)
        return fireEventually(answer)


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

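    # erasure-coding parameters: k=3 shares are required out of n=10 total,
    # so any 3 of the 10 shares suffice to recover the file ("3-of-10")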
    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
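# For example (illustrative): flip_bit("AB", 0) flips the low bit of "A"
# (0x41 -> 0x40), yielding "@B".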

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
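
# Example usages (illustrative, mirroring calls later in this file):
#   corrupt(None, storage, "pubkey")                 # every share's pubkey field
#   corrupt(None, storage, ("share_hash_chain", 4))  # 4 bytes into that field
#   corrupt(None, storage, "share_data", range(5))   # only shares 0..4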

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer, on
    # 10 of them. When we publish to 3 peers, we should get either 3 or 4
    # shares per peer. When we publish to zero peers, we should get a
    # NotEnoughSharesError.

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

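    # SDMF header layout exercised by the tests below (fixed byte offsets,
    # as used by the integer arguments to _test_corrupt_all): 0 version byte,
    # 1 seqnum, 9 root hash (R), 41 salt/IV, 57 k, 58 N, 59 segsize,
    # 67 datalen; the pubkey begins at offset 107 (see corrupt() above).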
1055     def test_corrupt_all_verbyte(self):
1056         # when the version byte is not 0, we hit an assertion error in
1057         # unpack_share().
1058         d = self._test_corrupt_all(0, "AssertionError")
1059         def _check_servermap(servermap):
1060             # and the dump should mention the problems
1061             s = StringIO()
1062             dump = servermap.dump(s).getvalue()
1063             self.failUnless("10 PROBLEMS" in dump, dump)
1064         d.addCallback(_check_servermap)
1065         return d
1066
1067     def test_corrupt_all_seqnum(self):
1068         # a corrupt sequence number will trigger a bad signature
1069         return self._test_corrupt_all(1, "signature is invalid")
1070
1071     def test_corrupt_all_R(self):
1072         # a corrupt root hash will trigger a bad signature
1073         return self._test_corrupt_all(9, "signature is invalid")
1074
1075     def test_corrupt_all_IV(self):
1076         # a corrupt salt/IV will trigger a bad signature
1077         return self._test_corrupt_all(41, "signature is invalid")
1078
1079     def test_corrupt_all_k(self):
1080         # a corrupt 'k' will trigger a bad signature
1081         return self._test_corrupt_all(57, "signature is invalid")
1082
1083     def test_corrupt_all_N(self):
1084         # a corrupt 'N' will trigger a bad signature
1085         return self._test_corrupt_all(58, "signature is invalid")
1086
1087     def test_corrupt_all_segsize(self):
1088         # a corrupt segsize will trigger a bad signature
1089         return self._test_corrupt_all(59, "signature is invalid")
1090
1091     def test_corrupt_all_datalen(self):
1092         # a corrupt data length will trigger a bad signature
1093         return self._test_corrupt_all(67, "signature is invalid")
1094
1095     def test_corrupt_all_pubkey(self):
1096         # a corrupt pubkey won't match the URI's fingerprint. We need to
1097         # remove the pubkey from the filenode, or else it won't bother trying
1098         # to update it.
1099         self._fn._pubkey = None
1100         return self._test_corrupt_all("pubkey",
1101                                       "pubkey doesn't match fingerprint")
1102
1103     def test_corrupt_all_sig(self):
1104         # a corrupt signature is a bad one
1105         # the signature runs from about [543:799], depending upon the length
1106         # of the pubkey
1107         return self._test_corrupt_all("signature", "signature is invalid")
1108
1109     def test_corrupt_all_share_hash_chain_number(self):
1110         # a corrupt share hash chain entry will show up as a bad hash. If we
1111         # mangle the first byte, that will look like a bad hash number,
1112         # causing an IndexError
1113         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1114
1115     def test_corrupt_all_share_hash_chain_hash(self):
1116         # a corrupt share hash chain entry will show up as a bad hash. If we
1117         # mangle a few bytes in, that will look like a bad hash.
1118         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1119
1120     def test_corrupt_all_block_hash_tree(self):
1121         return self._test_corrupt_all("block_hash_tree",
1122                                       "block hash tree failure")
1123
1124     def test_corrupt_all_block(self):
1125         return self._test_corrupt_all("share_data", "block hash tree failure")
1126
1127     def test_corrupt_all_encprivkey(self):
1128         # a corrupted privkey won't even be noticed by the reader, only by a
1129         # writer.
1130         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1131
1132
1133     def test_corrupt_all_seqnum_late(self):
1134         # corrupting the seqnum between mapupdate and retrieve should result
1135         # in NotEnoughSharesError, since each share will look invalid
1136         def _check(res):
1137             f = res[0]
1138             self.failUnless(f.check(NotEnoughSharesError))
1139             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1140         return self._test_corrupt_all(1, "ran out of peers",
1141                                       corrupt_early=False,
1142                                       failure_checker=_check)
1143
1144     def test_corrupt_all_block_hash_tree_late(self):
1145         def _check(res):
1146             f = res[0]
1147             self.failUnless(f.check(NotEnoughSharesError))
1148         return self._test_corrupt_all("block_hash_tree",
1149                                       "block hash tree failure",
1150                                       corrupt_early=False,
1151                                       failure_checker=_check)
1152
1153
1154     def test_corrupt_all_block_late(self):
1155         def _check(res):
1156             f = res[0]
1157             self.failUnless(f.check(NotEnoughSharesError))
1158         return self._test_corrupt_all("share_data", "block hash tree failure",
1159                                       corrupt_early=False,
1160                                       failure_checker=_check)
1161
1162
1163     def test_basic_pubkey_at_end(self):
1164         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1165         # download to succeed but forcing a bunch of retries first. Note that
1166         # this is rather pessimistic: our Retrieve process will throw away
1167         # the whole share if the pubkey is bad, even though the rest of the
1168         # share might be good.
1169
1170         self._fn._pubkey = None
1171         k = self._fn.get_required_shares()
1172         N = self._fn.get_total_shares()
1173         d = defer.succeed(None)
1174         d.addCallback(corrupt, self._storage, "pubkey",
1175                       shnums_to_corrupt=range(0, N-k))
1176         d.addCallback(lambda res: self.make_servermap())
1177         def _do_retrieve(servermap):
1178             self.failUnless(servermap.problems)
1179             self.failUnless("pubkey doesn't match fingerprint"
1180                             in str(servermap.problems[0]))
1181             ver = servermap.best_recoverable_version()
1182             r = Retrieve(self._fn, servermap, ver)
1183             return r.download()
1184         d.addCallback(_do_retrieve)
1185         d.addCallback(lambda new_contents:
1186                       self.failUnlessEqual(new_contents, self.CONTENTS))
1187         return d
1188
1189     def test_corrupt_some(self):
1190         # corrupt the data of the first five shares (so the servermap thinks
1191         # they're good but retrieve marks them as bad), so that the
1192         # MODE_READ set of 6 will be insufficient, forcing node.download to
1193         # retry with more servers.
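             # (MODE_READ normally settles for k+epsilon shares -- six here,
             # assuming k=3 and epsilon=3 -- so with the first five marked bad
             # the initial set holds fewer than k good shares)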
1194         corrupt(None, self._storage, "share_data", range(5))
1195         d = self.make_servermap()
1196         def _do_retrieve(servermap):
1197             ver = servermap.best_recoverable_version()
1198             self.failUnless(ver)
1199             return self._fn.download_best_version()
1200         d.addCallback(_do_retrieve)
1201         d.addCallback(lambda new_contents:
1202                       self.failUnlessEqual(new_contents, self.CONTENTS))
1203         return d
1204
1205     def test_download_fails(self):
1206         corrupt(None, self._storage, "signature")
1207         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1208                             "no recoverable versions",
1209                             self._fn.download_best_version)
1210         return d
1211
1212
1213 class CheckerMixin:
1214     def check_good(self, r, where):
1215         self.failUnless(r.is_healthy(), where)
1216         return r
1217
1218     def check_bad(self, r, where):
1219         self.failIf(r.is_healthy(), where)
1220         return r
1221
1222     def check_expected_failure(self, r, expected_exception, substring, where):
1223         for (peerid, storage_index, shnum, f) in r.problems:
1224             if f.check(expected_exception):
1225                 self.failUnless(substring in str(f),
1226                                 "%s: substring '%s' not in '%s'" %
1227                                 (where, substring, str(f)))
1228                 return
1229         self.fail("%s: didn't see expected exception %s in problems %s" %
1230                   (where, expected_exception, r.problems))
1231
1232
1233 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1234     def setUp(self):
1235         return self.publish_one()
1236
1237
1238     def test_check_good(self):
1239         d = self._fn.check(Monitor())
1240         d.addCallback(self.check_good, "test_check_good")
1241         return d
1242
1243     def test_check_no_shares(self):
1244         for shares in self._storage._peers.values():
1245             shares.clear()
1246         d = self._fn.check(Monitor())
1247         d.addCallback(self.check_bad, "test_check_no_shares")
1248         return d
1249
1250     def test_check_not_enough_shares(self):
1251         for shares in self._storage._peers.values():
1252             for shnum in shares.keys():
1253                 if shnum > 0:
1254                     del shares[shnum]
1255         d = self._fn.check(Monitor())
1256         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1257         return d
1258
1259     def test_check_all_bad_sig(self):
1260         corrupt(None, self._storage, 1) # bad sig
1261         d = self._fn.check(Monitor())
1262         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1263         return d
1264
1265     def test_check_all_bad_blocks(self):
1266         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1267         # the Checker won't notice this, since it doesn't look at actual data
1268         d = self._fn.check(Monitor())
1269         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1270         return d
1271
1272     def test_verify_good(self):
1273         d = self._fn.check(Monitor(), verify=True)
1274         d.addCallback(self.check_good, "test_verify_good")
1275         return d
1276
1277     def test_verify_all_bad_sig(self):
1278         corrupt(None, self._storage, 1) # bad sig
1279         d = self._fn.check(Monitor(), verify=True)
1280         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1281         return d
1282
1283     def test_verify_one_bad_sig(self):
1284         corrupt(None, self._storage, 1, [9]) # bad sig
1285         d = self._fn.check(Monitor(), verify=True)
1286         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1287         return d
1288
1289     def test_verify_one_bad_block(self):
1290         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1291         # the Verifier *will* notice this, since it examines every byte
1292         d = self._fn.check(Monitor(), verify=True)
1293         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1294         d.addCallback(self.check_expected_failure,
1295                       CorruptShareError, "block hash tree failure",
1296                       "test_verify_one_bad_block")
1297         return d
1298
1299     def test_verify_one_bad_sharehash(self):
1300         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1301         d = self._fn.check(Monitor(), verify=True)
1302         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1303         d.addCallback(self.check_expected_failure,
1304                       CorruptShareError, "corrupt hashes",
1305                       "test_verify_one_bad_sharehash")
1306         return d
1307
1308     def test_verify_one_bad_encprivkey(self):
1309         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1310         d = self._fn.check(Monitor(), verify=True)
1311         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1312         d.addCallback(self.check_expected_failure,
1313                       CorruptShareError, "invalid privkey",
1314                       "test_verify_one_bad_encprivkey")
1315         return d
1316
1317     def test_verify_one_bad_encprivkey_uncheckable(self):
1318         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1319         readonly_fn = self._fn.get_readonly()
1320         # a read-only node has no way to validate the privkey
1321         d = readonly_fn.check(Monitor(), verify=True)
1322         d.addCallback(self.check_good,
1323                       "test_verify_one_bad_encprivkey_uncheckable")
1324         return d
1325
1326 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1327
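         # get_shares flattens FakeStorage._peers into a single dict, e.g.
         #   {peerid: {0: data0, 3: data3}} -> {(peerid, 0): data0,
         #                                      (peerid, 3): data3}
         # so that two snapshots taken at different times can be compared
         # with plain dict equality (see failIfSharesChanged below)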
1328     def get_shares(self, s):
1329         all_shares = {} # maps (peerid, shnum) to share data
1330         for peerid in s._peers:
1331             shares = s._peers[peerid]
1332             for shnum in shares:
1333                 data = shares[shnum]
1334                 all_shares[ (peerid, shnum) ] = data
1335         return all_shares
1336
1337     def copy_shares(self, ignored=None):
1338         self.old_shares.append(self.get_shares(self._storage))
1339
1340     def test_repair_nop(self):
1341         self.old_shares = []
1342         d = self.publish_one()
1343         d.addCallback(self.copy_shares)
1344         d.addCallback(lambda res: self._fn.check(Monitor()))
1345         d.addCallback(lambda check_results: self._fn.repair(check_results))
1346         def _check_results(rres):
1347             self.failUnless(IRepairResults.providedBy(rres))
1348             # TODO: examine results
1349
1350             self.copy_shares()
1351
1352             initial_shares = self.old_shares[0]
1353             new_shares = self.old_shares[1]
1354             # TODO: this really shouldn't change anything. When we implement
1355             # a "minimal-bandwidth" repairer, change this test to assert:
1356             #self.failUnlessEqual(new_shares, initial_shares)
1357
1358             # all shares should be in the same place as before
1359             self.failUnlessEqual(set(initial_shares.keys()),
1360                                  set(new_shares.keys()))
1361             # but they should all be at a newer seqnum. The IV will be
1362             # different, so the roothash will be too.
1363             for key in initial_shares:
1364                 (version0,
1365                  seqnum0,
1366                  root_hash0,
1367                  IV0,
1368                  k0, N0, segsize0, datalen0,
1369                  o0) = unpack_header(initial_shares[key])
1370                 (version1,
1371                  seqnum1,
1372                  root_hash1,
1373                  IV1,
1374                  k1, N1, segsize1, datalen1,
1375                  o1) = unpack_header(new_shares[key])
1376                 self.failUnlessEqual(version0, version1)
1377                 self.failUnlessEqual(seqnum0+1, seqnum1)
1378                 self.failUnlessEqual(k0, k1)
1379                 self.failUnlessEqual(N0, N1)
1380                 self.failUnlessEqual(segsize0, segsize1)
1381                 self.failUnlessEqual(datalen0, datalen1)
1382         d.addCallback(_check_results)
1383         return d
1384
1385     def failIfSharesChanged(self, ignored=None):
1386         old_shares = self.old_shares[-2]
1387         current_shares = self.old_shares[-1]
1388         self.failUnlessEqual(old_shares, current_shares)
1389
1390     def test_merge(self):
1391         self.old_shares = []
1392         d = self.publish_multiple()
1393         # repair will refuse to merge multiple highest seqnums unless you
1394         # pass force=True
1395         d.addCallback(lambda res:
1396                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1397                                           1:4,3:4,5:4,7:4,9:4}))
1398         d.addCallback(self.copy_shares)
1399         d.addCallback(lambda res: self._fn.check(Monitor()))
1400         def _try_repair(check_results):
1401             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1402             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1403                                  self._fn.repair, check_results)
1404             d2.addCallback(self.copy_shares)
1405             d2.addCallback(self.failIfSharesChanged)
1406             d2.addCallback(lambda res: check_results)
1407             return d2
1408         d.addCallback(_try_repair)
1409         d.addCallback(lambda check_results:
1410                       self._fn.repair(check_results, force=True))
1411         # this should give us 10 shares of the highest roothash
1412         def _check_repair_results(rres):
1413             pass # TODO
1414         d.addCallback(_check_repair_results)
1415         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1416         def _check_smap(smap):
1417             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1418             self.failIf(smap.unrecoverable_versions())
1419             # now, which should have won?
1420             roothash_s4a = self.get_roothash_for(3)
1421             roothash_s4b = self.get_roothash_for(4)
1422             if roothash_s4b > roothash_s4a:
1423                 expected_contents = self.CONTENTS[4]
1424             else:
1425                 expected_contents = self.CONTENTS[3]
1426             new_versionid = smap.best_recoverable_version()
1427             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1428             d2 = self._fn.download_version(smap, new_versionid)
1429             d2.addCallback(self.failUnlessEqual, expected_contents)
1430             return d2
1431         d.addCallback(_check_smap)
1432         return d
1433
1434     def test_non_merge(self):
1435         self.old_shares = []
1436         d = self.publish_multiple()
1437         # repair should not refuse a repair that doesn't need to merge. In
1438         # this case, we combine v2 with v3. The repair should ignore v2 and
1439         # copy v3 into a new v5.
1440         d.addCallback(lambda res:
1441                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1442                                           1:3,3:3,5:3,7:3,9:3}))
1443         d.addCallback(lambda res: self._fn.check(Monitor()))
1444         d.addCallback(lambda check_results: self._fn.repair(check_results))
1445         # this should give us 10 shares of v3
1446         def _check_repair_results(rres):
1447             pass # TODO
1448         d.addCallback(_check_repair_results)
1449         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1450         def _check_smap(smap):
1451             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1452             self.failIf(smap.unrecoverable_versions())
1453             # now, which should have won?
1454             roothash_s4a = self.get_roothash_for(3)
1455             expected_contents = self.CONTENTS[3]
1456             new_versionid = smap.best_recoverable_version()
1457             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1458             d2 = self._fn.download_version(smap, new_versionid)
1459             d2.addCallback(self.failUnlessEqual, expected_contents)
1460             return d2
1461         d.addCallback(_check_smap)
1462         return d
1463
1464     def get_roothash_for(self, index):
1465         # return the roothash for the first share we see in the saved set
1466         shares = self._copied_shares[index]
1467         for peerid in shares:
1468             for shnum in shares[peerid]:
1469                 share = shares[peerid][shnum]
1470                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1471                           unpack_header(share)
1472                 return root_hash
1473
1474 class MultipleEncodings(unittest.TestCase):
1475     def setUp(self):
1476         self.CONTENTS = "New contents go here"
1477         num_peers = 20
1478         self._client = FakeClient(num_peers)
1479         self._storage = self._client._storage
1480         d = self._client.create_mutable_file(self.CONTENTS)
1481         def _created(node):
1482             self._fn = node
1483         d.addCallback(_created)
1484         return d
1485
1486     def _encode(self, k, n, data):
1487         # encode 'data' into a peerid->shares dict.
1488
1489         fn2 = FastMutableFileNode(self._client)
1490         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1491         # and _fingerprint
1492         fn = self._fn
1493         fn2.init_from_uri(fn.get_uri())
1494         # then we copy over other fields that are normally fetched from the
1495         # existing shares
1496         fn2._pubkey = fn._pubkey
1497         fn2._privkey = fn._privkey
1498         fn2._encprivkey = fn._encprivkey
1499         # and set the encoding parameters to something completely different
1500         fn2._required_shares = k
1501         fn2._total_shares = n
1502
1503         s = self._client._storage
1504         s._peers = {} # clear existing storage
1505         p2 = Publish(fn2, None)
1506         d = p2.publish(data)
1507         def _published(res):
1508             shares = s._peers
1509             s._peers = {}
1510             return shares
1511         d.addCallback(_published)
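             # (each publish writes into the shared FakeStorage, so clearing
             # s._peers beforehand and swapping it out afterwards captures
             # this encoding's shares in isolation)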
1512         return d
1513
1514     def make_servermap(self, mode=MODE_READ, oldmap=None):
1515         if oldmap is None:
1516             oldmap = ServerMap()
1517         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1518         d = smu.update()
1519         return d
1520
1521     def test_multiple_encodings(self):
1522         # we encode the same file in three different ways (3-of-10, 4-of-9,
1523         # and 4-of-7), then mix up the shares, to make sure that download
1524         # survives seeing a variety of encodings. This is tricky to set up.
1525
1526         contents1 = "Contents for encoding 1 (3-of-10) go here"
1527         contents2 = "Contents for encoding 2 (4-of-9) go here"
1528         contents3 = "Contents for encoding 3 (4-of-7) go here"
1529
1530         # we make a retrieval object that doesn't know what encoding
1531         # parameters to use
1532         fn3 = FastMutableFileNode(self._client)
1533         fn3.init_from_uri(self._fn.get_uri())
1534
1535         # now we upload a file through fn1, and grab its shares
1536         d = self._encode(3, 10, contents1)
1537         def _encoded_1(shares):
1538             self._shares1 = shares
1539         d.addCallback(_encoded_1)
1540         d.addCallback(lambda res: self._encode(4, 9, contents2))
1541         def _encoded_2(shares):
1542             self._shares2 = shares
1543         d.addCallback(_encoded_2)
1544         d.addCallback(lambda res: self._encode(4, 7, contents3))
1545         def _encoded_3(shares):
1546             self._shares3 = shares
1547         d.addCallback(_encoded_3)
1548
1549         def _merge(res):
1550             log.msg("merging sharelists")
1551             # we merge the shares from the three sets, leaving each shnum in
1552             # its original location, but using a share from set1, set2, or
1553             # set3 according to the following sequence:
1554             #
1555             #  4-of-9  a  s2
1556             #  4-of-9  b  s2
1557             #  4-of-7  c   s3
1558             #  4-of-9  d  s2
1559             #  3-of-10 e s1
1560             #  3-of-10 f s1
1561             #  3-of-10 g s1
1562             #  4-of-9  h  s2
1563             #
1564             # so that neither form can be recovered until fetch [g], at which
1565             # point version-s1 (the 3-of-10 form) should be recoverable. If
1566             # the implementation latches on to the first version it sees,
1567             # then s2 will be recoverable at fetch [h].
1568
1569             # Later, when we implement code that handles multiple versions,
1570             # we can use this framework to assert that all recoverable
1571             # versions are retrieved, and test that 'epsilon' does its job
1572
1573             places = [2, 2, 3, 2, 1, 1, 1, 2]
1574
1575             sharemap = {}
1576
1577             for i,peerid in enumerate(self._client._peerids):
1578                 peerid_s = shortnodeid_b2a(peerid)
1579                 for shnum in self._shares1.get(peerid, {}):
1580                     if shnum < len(places):
1581                         which = places[shnum]
1582                     else:
1583                         which = "x"
1584                     self._client._storage._peers[peerid] = peers = {}
1585                     in_1 = shnum in self._shares1[peerid]
1586                     in_2 = shnum in self._shares2.get(peerid, {})
1587                     in_3 = shnum in self._shares3.get(peerid, {})
1588                     #print peerid_s, shnum, which, in_1, in_2, in_3
1589                     if which == 1:
1590                         if in_1:
1591                             peers[shnum] = self._shares1[peerid][shnum]
1592                             sharemap[shnum] = peerid
1593                     elif which == 2:
1594                         if in_2:
1595                             peers[shnum] = self._shares2[peerid][shnum]
1596                             sharemap[shnum] = peerid
1597                     elif which == 3:
1598                         if in_3:
1599                             peers[shnum] = self._shares3[peerid][shnum]
1600                             sharemap[shnum] = peerid
1601
1602             # we don't bother placing any other shares
1603             # now sort the sequence so that share 0 is returned first
1604             new_sequence = [sharemap[shnum]
1605                             for shnum in sorted(sharemap.keys())]
1606             self._client._storage._sequence = new_sequence
1607             log.msg("merge done")
1608         d.addCallback(_merge)
1609         d.addCallback(lambda res: fn3.download_best_version())
1610         def _retrieved(new_contents):
1611             # the current specified behavior is "first version recoverable"
1612             self.failUnlessEqual(new_contents, contents1)
1613         d.addCallback(_retrieved)
1614         return d
1615
1616
1617 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1618
1619     def setUp(self):
1620         return self.publish_multiple()
1621
1622     def test_multiple_versions(self):
1623         # if we see a mix of versions in the grid, download_best_version
1624         # should get the latest one
1625         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1626         d = self._fn.download_best_version()
1627         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1628         # and the checker should report problems
1629         d.addCallback(lambda res: self._fn.check(Monitor()))
1630         d.addCallback(self.check_bad, "test_multiple_versions")
1631
1632         # but if everything is at version 2, that's what we should download
1633         d.addCallback(lambda res:
1634                       self._set_versions(dict([(i,2) for i in range(10)])))
1635         d.addCallback(lambda res: self._fn.download_best_version())
1636         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1637         # if exactly one share is at version 3, we should still get v2
1638         d.addCallback(lambda res:
1639                       self._set_versions({0:3}))
1640         d.addCallback(lambda res: self._fn.download_best_version())
1641         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1642         # but the servermap should see the unrecoverable version. This
1643         # depends upon the single newer share being queried early.
1644         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1645         def _check_smap(smap):
1646             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1647             newer = smap.unrecoverable_newer_versions()
1648             self.failUnlessEqual(len(newer), 1)
1649             verinfo, health = newer.items()[0]
1650             self.failUnlessEqual(verinfo[0], 4)
1651             self.failUnlessEqual(health, (1,3))
1652             self.failIf(smap.needs_merge())
1653         d.addCallback(_check_smap)
1654         # if we have a mix of two parallel versions (s4a and s4b), we could
1655         # recover either
1656         d.addCallback(lambda res:
1657                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1658                                           1:4,3:4,5:4,7:4,9:4}))
1659         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1660         def _check_smap_mixed(smap):
1661             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1662             newer = smap.unrecoverable_newer_versions()
1663             self.failUnlessEqual(len(newer), 0)
1664             self.failUnless(smap.needs_merge())
1665         d.addCallback(_check_smap_mixed)
1666         d.addCallback(lambda res: self._fn.download_best_version())
1667         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1668                                                   res == self.CONTENTS[4]))
1669         return d
1670
1671     def test_replace(self):
1672         # if we see a mix of versions in the grid, we should be able to
1673         # replace them all with a newer version
1674
1675         # if exactly one share is at version 3, we should download (and
1676         # replace) v2, and the result should be v4. Note that the index we
1677         # give to _set_versions is different from the sequence number.
1678         target = dict([(i,2) for i in range(10)]) # seqnum3
1679         target[0] = 3 # seqnum4
1680         self._set_versions(target)
1681
1682         def _modify(oldversion, servermap, first_time):
1683             return oldversion + " modified"
1684         d = self._fn.modify(_modify)
1685         d.addCallback(lambda res: self._fn.download_best_version())
1686         expected = self.CONTENTS[2] + " modified"
1687         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1688         # and the servermap should indicate that the outlier was replaced too
1689         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1690         def _check_smap(smap):
1691             self.failUnlessEqual(smap.highest_seqnum(), 5)
1692             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1693             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1694         d.addCallback(_check_smap)
1695         return d
1696
1697
1698 class Utils(unittest.TestCase):
1699     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1700         # we compare this against sets of integers
1701         x = set(range(x_start, x_start+x_length))
1702         y = set(range(y_start, y_start+y_length))
1703         should_be_inside = x.issubset(y)
1704         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1705                                                          y_start, y_length),
1706                              str((x_start, x_length, y_start, y_length)))
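             # (e.g. x_start=10, x_length=5 gives x={10..14}; with y_start=8,
             # y_length=8, y={8..15}, so _inside(10, 5, 8, 8) must be True)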
1707
1708     def test_cache_inside(self):
1709         c = ResponseCache()
1710         x_start = 10
1711         x_length = 5
1712         for y_start in range(8, 17):
1713             for y_length in range(8):
1714                 self._do_inside(c, x_start, x_length, y_start, y_length)
1715
1716     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1717         # we compare this against sets of integers
1718         x = set(range(x_start, x_start+x_length))
1719         y = set(range(y_start, y_start+y_length))
1720         overlap = bool(x.intersection(y))
1721         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1722                                                       y_start, y_length),
1723                              str((x_start, x_length, y_start, y_length)))
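             # (e.g. x={10..14} and y_start=13, y_length=3 share {13, 14}, so
             # _does_overlap(10, 5, 13, 3) must be True)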
1724
1725     def test_cache_overlap(self):
1726         c = ResponseCache()
1727         x_start = 10
1728         x_length = 5
1729         for y_start in range(8, 17):
1730             for y_length in range(8):
1731                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1732
1733     def test_cache(self):
1734         c = ResponseCache()
1735         # xdata = base62.b2a(os.urandom(100))[:100]
1736         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1737         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1738         nope = (None, None)
1739         c.add("v1", 1, 0, xdata, "time0")
1740         c.add("v1", 1, 2000, ydata, "time1")
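             # (a read is a hit only when the requested range lies entirely
             # inside a single cached fragment; anything that pokes past an
             # edge comes back as nope)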
1741         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1742         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1743         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1744         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1745         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1746         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1747         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1748         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1749         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1750         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1751         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1752         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1753         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1754         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1755         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1756         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1757         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1758         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1759
1760         # optional: join fragments
1761         c = ResponseCache()
1762         c.add("v1", 1, 0, xdata[:10], "time0")
1763         c.add("v1", 1, 10, xdata[10:20], "time1")
1764         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1765
1766 class Exceptions(unittest.TestCase):
1767     def test_repr(self):
1768         nmde = NeedMoreDataError(100, 50, 100)
1769         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1770         ucwe = UncoordinatedWriteError()
1771         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1772
1773 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1774 # instances which always succeed. So we need a less-fake one.
1775
1776 class IntentionalError(Exception):
1777     pass
1778
1779 class LocalWrapper:
1780     def __init__(self, original):
1781         self.original = original
1782         self.broken = False
1783         self.post_call_notifier = None
1784     def callRemote(self, methname, *args, **kwargs):
1785         def _call():
1786             if self.broken:
1787                 raise IntentionalError("I was asked to break")
1788             meth = getattr(self.original, "remote_" + methname)
1789             return meth(*args, **kwargs)
1790         d = fireEventually()
1791         d.addCallback(lambda res: _call())
1792         if self.post_call_notifier:
1793             d.addCallback(self.post_call_notifier, methname)
1794         return d
1795
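     # a minimal usage sketch of LocalWrapper (hypothetical helper, not used
     # by the tests; the remote method name is just an example): once .broken
     # is set, every subsequent call errbacks with IntentionalError
     def _demo_local_wrapper(ss, si):
         lw = LocalWrapper(ss)
         d = lw.callRemote("get_buckets", si)  # runs ss.remote_get_buckets
         def _break_and_retry(res):
             lw.broken = True
             return lw.callRemote("get_buckets", si)  # this one errbacks
         d.addCallback(_break_and_retry)
         return d
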
1796 class LessFakeClient(FakeClient):
1797
1798     def __init__(self, basedir, num_peers=10):
1799         self._num_peers = num_peers
1800         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1801                          for i in range(self._num_peers)]
1802         self._connections = {}
1803         for peerid in self._peerids:
1804             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1805             make_dirs(peerdir)
1806             ss = storage.StorageServer(peerdir)
1807             ss.setNodeID(peerid)
1808             lw = LocalWrapper(ss)
1809             self._connections[peerid] = lw
1810         self.nodeid = "fakenodeid"
1811
1812
1813 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1814     def test_publish_surprise(self):
1815         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1816         self.client = LessFakeClient(basedir)
1817         d = self.client.create_mutable_file("contents 1")
1818         def _created(n):
1819             d = defer.succeed(None)
1820             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1821             def _got_smap1(smap):
1822                 # stash the old state of the file
1823                 self.old_map = smap
1824             d.addCallback(_got_smap1)
1825             # then modify the file, leaving the old map untouched
1826             d.addCallback(lambda res: log.msg("starting winning write"))
1827             d.addCallback(lambda res: n.overwrite("contents 2"))
1828             # now attempt to modify the file with the old servermap. This
1829             # will look just like an uncoordinated write, in which every
1830             # single share got updated between our mapupdate and our publish
1831             d.addCallback(lambda res: log.msg("starting doomed write"))
1832             d.addCallback(lambda res:
1833                           self.shouldFail(UncoordinatedWriteError,
1834                                           "test_publish_surprise", None,
1835                                           n.upload,
1836                                           "contents 2a", self.old_map))
1837             return d
1838         d.addCallback(_created)
1839         return d
1840
1841     def test_retrieve_surprise(self):
1842         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1843         self.client = LessFakeClient(basedir)
1844         d = self.client.create_mutable_file("contents 1")
1845         def _created(n):
1846             d = defer.succeed(None)
1847             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1848             def _got_smap1(smap):
1849                 # stash the old state of the file
1850                 self.old_map = smap
1851             d.addCallback(_got_smap1)
1852             # then modify the file, leaving the old map untouched
1853             d.addCallback(lambda res: log.msg("starting winning write"))
1854             d.addCallback(lambda res: n.overwrite("contents 2"))
1855             # now attempt to retrieve the old version with the old servermap.
1856             # This will look like someone has changed the file since we
1857             # updated the servermap.
1858             d.addCallback(lambda res: n._cache._clear())
1859             d.addCallback(lambda res: log.msg("starting doomed read"))
1860             d.addCallback(lambda res:
1861                           self.shouldFail(NotEnoughSharesError,
1862                                           "test_retrieve_surprise",
1863                                           "ran out of peers: have 0 shares (k=3)",
1864                                           n.download_version,
1865                                           self.old_map,
1866                                           self.old_map.best_recoverable_version(),
1867                                           ))
1868             return d
1869         d.addCallback(_created)
1870         return d
1871
1872     def test_unexpected_shares(self):
1873         # upload the file, take a servermap, shut down one of the servers,
1874         # upload it again (causing shares to appear on a new server), then
1875         # upload using the old servermap. The last upload should fail with an
1876         # UncoordinatedWriteError, because of the shares that didn't appear
1877         # in the servermap.
1878         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1879         self.client = LessFakeClient(basedir)
1880         d = self.client.create_mutable_file("contents 1")
1881         def _created(n):
1882             d = defer.succeed(None)
1883             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1884             def _got_smap1(smap):
1885                 # stash the old state of the file
1886                 self.old_map = smap
1887                 # now shut down one of the servers
1888                 peer0 = list(smap.make_sharemap()[0])[0]
1889                 self.client._connections.pop(peer0)
1890                 # then modify the file, leaving the old map untouched
1891                 log.msg("starting winning write")
1892                 return n.overwrite("contents 2")
1893             d.addCallback(_got_smap1)
1894             # now attempt to modify the file with the old servermap. This
1895             # will look just like an uncoordinated write, in which every
1896             # single share got updated between our mapupdate and our publish
1897             d.addCallback(lambda res: log.msg("starting doomed write"))
1898             d.addCallback(lambda res:
1899                           self.shouldFail(UncoordinatedWriteError,
1900                                           "test_unexpected_shares", None,
1901                                           n.upload,
1902                                           "contents 2a", self.old_map))
1903             return d
1904         d.addCallback(_created)
1905         return d
1906
1907     def test_bad_server(self):
1908         # Break one server, then create the file: the initial publish should
1909         # complete with an alternate server. Breaking a second server should
1910         # not prevent an update from succeeding either.
1911         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1912         self.client = LessFakeClient(basedir, 20)
1913         # to make sure that one of the initial peers is broken, we have to
1914         # get creative. We create the keys, so we can figure out the storage
1915         # index, but we hold off on doing the initial publish until we've
1916         # broken the server on which the first share wants to be stored.
1917         n = FastMutableFileNode(self.client)
1918         d = defer.succeed(None)
1919         d.addCallback(n._generate_pubprivkeys)
1920         d.addCallback(n._generated)
1921         def _break_peer0(res):
1922             si = n.get_storage_index()
1923             peerlist = self.client.get_permuted_peers("storage", si)
1924             peerid0, connection0 = peerlist[0]
1925             peerid1, connection1 = peerlist[1]
1926             connection0.broken = True
1927             self.connection1 = connection1
1928         d.addCallback(_break_peer0)
1929         # now let the initial publish finally happen
1930         d.addCallback(lambda res: n._upload("contents 1", None))
1931         # that ought to work
1932         d.addCallback(lambda res: n.download_best_version())
1933         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1934         # now break the second peer
1935         def _break_peer1(res):
1936             self.connection1.broken = True
1937         d.addCallback(_break_peer1)
1938         d.addCallback(lambda res: n.overwrite("contents 2"))
1939         # that ought to work too
1940         d.addCallback(lambda res: n.download_best_version())
1941         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1942         return d
1943
1944     def test_bad_server_overlap(self):
1945         # like test_bad_server, but with no extra unused servers to fall back
1946         # upon. This means that we must re-use a server which we've already
1947         # used. If we don't remember the fact that we sent them one share
1948         # already, we'll mistakenly think we're experiencing an
1949         # UncoordinatedWriteError.
1950
1951         # Break one server, then create the file: the initial publish should
1952         # complete with an alternate server. Breaking a second server should
1953         # not prevent an update from succeeding either.
1954         basedir = os.path.join("mutable/CollidingWrites/test_bad_server_overlap")
1955         self.client = LessFakeClient(basedir, 10)
1956
1957         peerids = sorted(self.client._connections.keys())
1958         self.client._connections[peerids[0]].broken = True
1959
1960         d = self.client.create_mutable_file("contents 1")
1961         def _created(n):
1962             d = n.download_best_version()
1963             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1964             # now break one of the remaining servers
1965             def _break_second_server(res):
1966                 self.client._connections[peerids[1]].broken = True
1967             d.addCallback(_break_second_server)
1968             d.addCallback(lambda res: n.overwrite("contents 2"))
1969             # that ought to work too
1970             d.addCallback(lambda res: n.download_best_version())
1971             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1972             return d
1973         d.addCallback(_created)
1974         return d
1975
1976     def test_publish_all_servers_bad(self):
1977         # Break all servers: the publish should fail
1978         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1979         self.client = LessFakeClient(basedir, 20)
1980         for connection in self.client._connections.values():
1981             connection.broken = True
1982         d = self.shouldFail(NotEnoughServersError,
1983                             "test_publish_all_servers_bad",
1984                             "Ran out of non-bad servers",
1985                             self.client.create_mutable_file, "contents")
1986         return d
1987
1988     def test_publish_no_servers(self):
1989         # no servers at all: the publish should fail
1990         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1991         self.client = LessFakeClient(basedir, 0)
1992         d = self.shouldFail(NotEnoughServersError,
1993                             "test_publish_no_servers",
1994                             "Ran out of non-bad servers",
1995                             self.client.create_mutable_file, "contents")
1996         return d
1997     test_publish_no_servers.timeout = 30
1998
1999
2000     def test_privkey_query_error(self):
2001         # when a servermap is updated with MODE_WRITE, it tries to get the
2002         # privkey. Something might go wrong during this query attempt.
2003         self.client = FakeClient(20)
2004         # we need some contents that are large enough to push the privkey out
2005         # of the early part of the file
2006         LARGE = "These are Larger contents" * 200 # about 5KB
2007         d = self.client.create_mutable_file(LARGE)
2008         def _created(n):
2009             self.uri = n.get_uri()
2010             self.n2 = self.client.create_node_from_uri(self.uri)
2011             # we start by doing a map update to figure out which is the first
2012             # server.
2013             return n.get_servermap(MODE_WRITE)
2014         d.addCallback(_created)
2015         d.addCallback(lambda res: fireEventually(res))
2016         def _got_smap1(smap):
2017             peer0 = list(smap.make_sharemap()[0])[0]
2018             # we tell the fake storage to answer this peer's queries first,
2019             # so that it will be asked for the privkey first
2020             self.client._storage._sequence = [peer0]
2021             # now we make the peer fail their second query
2022             self.client._storage._special_answers[peer0] = ["normal", "fail"]
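                 # (FakeStorage.read pops one entry per query: "normal" answers
                 # the first query as usual, and "fail" turns the second answer
                 # into a Failure(IntentionalError()))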
2023         d.addCallback(_got_smap1)
2024         # now we update a servermap from a new node (which doesn't have the
2025         # privkey yet, forcing it to use a separate privkey query). Each
2026         # query response will trigger a privkey query, and since we're using
2027         # _sequence to make the peer0 response come back first, we'll send it
2028         # a privkey query first, and _sequence will again ensure that the
2029         # peer0 query will also come back before the others, and then
2030         # _special_answers will make sure that the query raises an exception.
2031         # The whole point of these hijinks is to exercise the code in
2032         # _privkey_query_failed. Note that the map-update will succeed, since
2033         # we'll just get a copy from one of the other shares.
2034         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2035         # Using FakeStorage._sequence means there will be read requests still
2036         # floating around; wait for them to retire
2037         def _cancel_timer(res):
2038             if self.client._storage._pending_timer:
2039                 self.client._storage._pending_timer.cancel()
2040             return res
2041         d.addBoth(_cancel_timer)
2042         return d
2043
2044     def test_privkey_query_missing(self):
2045         # like test_privkey_query_error, but the shares are deleted by the
2046         # second query, instead of raising an exception.
2047         self.client = FakeClient(20)
2048         LARGE = "These are Larger contents" * 200 # about 5KB
2049         d = self.client.create_mutable_file(LARGE)
2050         def _created(n):
2051             self.uri = n.get_uri()
2052             self.n2 = self.client.create_node_from_uri(self.uri)
2053             return n.get_servermap(MODE_WRITE)
2054         d.addCallback(_created)
2055         d.addCallback(lambda res: fireEventually(res))
2056         def _got_smap1(smap):
2057             peer0 = list(smap.make_sharemap()[0])[0]
2058             self.client._storage._sequence = [peer0]
2059             self.client._storage._special_answers[peer0] = ["normal", "none"]
2060         d.addCallback(_got_smap1)
2061         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2062         def _cancel_timer(res):
2063             if self.client._storage._pending_timer:
2064                 self.client._storage._pending_timer.cancel()
2065             return res
2066         d.addBoth(_cancel_timer)
2067         return d