git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_mutable.py
start to factor server-connection-management into a distinct 'StorageServerFarmBroker...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_mutable.py
1
2 import os, struct
3 from cStringIO import StringIO
4 from twisted.trial import unittest
5 from twisted.internet import defer, reactor
6 from twisted.python import failure
7 from allmydata import uri
8 from allmydata.storage.server import StorageServer
9 from allmydata.immutable import download
10 from allmydata.util import base32, idlib
11 from allmydata.util.idlib import shortnodeid_b2a
12 from allmydata.util.hashutil import tagged_hash
13 from allmydata.util.fileutil import make_dirs
14 from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
15      FileTooLargeError, NotEnoughSharesError, IRepairResults
16 from allmydata.monitor import Monitor
17 from allmydata.test.common import ShouldFailMixin
18 from foolscap.api import eventually, fireEventually
19 from foolscap.logging import log
20 from allmydata.storage_client import StorageFarmBroker
21
22 from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
23 from allmydata.mutable.common import ResponseCache, \
24      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
25      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
26      NotEnoughServersError, CorruptShareError
27 from allmydata.mutable.retrieve import Retrieve
28 from allmydata.mutable.publish import Publish
29 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
30 from allmydata.mutable.layout import unpack_header, unpack_share
31 from allmydata.mutable.repairer import MustForceRepairError
32
33 import common_util as testutil
34
35 # this "FastMutableFileNode" exists solely to speed up tests by using smaller
36 # public/private keys. Once we switch to fast DSA-based keys, we can get rid
37 # of this.
38
class FastMutableFileNode(MutableFileNode):
    """MutableFileNode variant used only to make the tests faster.

    It overrides SIGNATURE_KEY_SIZE so that tests generate smaller (and
    therefore quicker to create) public/private keys.
    """
    # deliberately small key size: acceptable for tests only
    SIGNATURE_KEY_SIZE = 522
41
42 # this "FakeStorage" exists to put the share data in RAM and avoid using real
43 # network connections, both to speed up the tests and to reduce the amount of
44 # non-mutable.py code being exercised.
45
class FakeStorage:
    """In-RAM replacement for the whole collection of storage servers.

    Tests use it to examine and manipulate the published shares. It also
    lets tests control the order in which read queries are answered, to
    exercise more of the error-handling code in Retrieve.

    Note that the storage index is ignored: a FakeStorage instance can only
    be used for a single storage index.

    State:
      _peers: peerid -> {shnum: share data string}
      _sequence: None, or an iterable of peerids. When set, read() answers
          are held back and later released (peers in this order first) by
          _fire_readers().
      _pending: peerid -> list of (Deferred, shares) waiting to be fired.
      _pending_timer: the callLater handle for _fire_readers, or None.
      _special_answers: peerid -> list of modes ("fail", "none", "normal"),
          one of which is consumed by each read() of that peer.
    """

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        """Return a Deferred for peerid's {shnum: data} dict.

        storage_index is ignored. If a special answer is queued for this
        peer, the next one is consumed:
          - "fail": answer with a Failure wrapping IntentionalError
          - "none": answer with an empty dict
          - "normal": answer normally
        When _sequence is set, the answer is held until _fire_readers runs.
        """
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            # first held-back query: release everything that accumulates
            # during the next second
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        # Keep a list per peer: the same peer may be queried more than once
        # before the timer fires, and every one of those Deferreds must be
        # fired (a single slot would silently drop the earlier ones).
        self._pending.setdefault(peerid, []).append((d, shares))
        return d

    def _fire_readers(self):
        """Release every held-back read() answer.

        Peers named in _sequence are answered first, in that order. Any
        remaining peers are answered afterwards, in dict order.
        """
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            for (d, shares) in pending.pop(peerid, []):
                eventually(d.callback, shares)
        for waiters in pending.values():
            for (d, shares) in waiters:
                eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        """Write data at offset within share shnum held by peerid.

        The peer's share dict and the share itself are created on first
        use. storage_index is ignored.
        """
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
108
109
class FakeStorageServer:
    """Stand-in for a remote reference to a single storage server.

    Remote method calls are dispatched to the same-named local methods via
    fireEventually(). All share data lives in the FakeStorage passed in as
    `storage`, which is shared by every fake server.
    """
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0

    def callRemote(self, methname, *args, **kwargs):
        """Call self.<methname>(*args, **kwargs) after fireEventually()
        fires, and return a Deferred for the result."""
        def _dispatch(ignored):
            method = getattr(self, methname)
            return method(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_dispatch)
        return d

    def callRemoteOnly(self, methname, *args, **kwargs):
        """Fire-and-forget variant of callRemote.

        The outcome (result or failure) is swallowed and nothing is returned.
        """
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignored: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        # corruption reports are accepted and ignored
        pass

    def slot_readv(self, storage_index, shnums, readv):
        """Read (offset, length) ranges from this peer's shares.

        An empty/false `shnums` means every share. Returns a Deferred firing
        with {shnum: [data for each (offset, length) in readv]}.
        """
        d = self.storage.read(self.peerid, storage_index)
        def _extract(all_shares):
            answer = {}
            for sh in all_shares:
                if shnums and sh not in shnums:
                    continue
                pieces = answer[sh] = []
                data = all_shares[sh]
                for (start, size) in readv:
                    assert isinstance(start, (int, long)), start
                    assert isinstance(size, (int, long)), size
                    pieces.append(data[start:start+size])
            return answer
        d.addCallback(_extract)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        """Apply every write vector unconditionally (the tests always pass).

        For each share, the test-vector specimens are echoed back as the
        "read" data. Returns a Deferred (via fireEventually) firing with
        (True, {shnum: [specimen, ...]}).
        """
        echoed = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            specimens = []
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
                specimens.append(specimen)
            echoed[shnum] = specimens
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        return fireEventually((True, echoed))
163
164
165 # our "FakeClient" has just enough functionality of the real Client to let
166 # the tests run.
167
class FakeClient:
    """Just enough of the real Client for the mutable-file tests to run.

    All of its FakeStorageServers share a single FakeStorage (self._storage),
    so tests can inspect and tamper with every published share.
    """
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self.nodeid = "fakenodeid"
        self.storage_broker = StorageFarmBroker()
        # peerids are derived from a tagged hash of the index, so they are
        # the same on every run
        for i in range(self._num_peers):
            serverid = tagged_hash("peerid", "%d" % i)[:20]
            server = FakeStorageServer(serverid, self._storage)
            self.storage_broker.add_server(serverid, server)

    def get_all_serverids(self):
        return self.storage_broker.get_all_serverids()

    def debug_break_connection(self, peerid):
        self.storage_broker.servers[peerid].broken = True

    def debug_remove_connection(self, peerid):
        self.storage_broker.servers.pop(peerid)

    def debug_get_connection(self, peerid):
        return self.storage_broker.servers[peerid]

    def get_encoding_parameters(self):
        # 3-of-10 encoding
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"

    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        """Create a new node of mutable_file_node_class holding `contents`.

        Returns a Deferred that fires with the node once creation finishes.
        """
        node = self.mutable_file_node_class(self)
        d = node.create(contents)
        d.addCallback(lambda ignored: node)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        """Build a mutable_file_node_class node for the mutable-file URI u."""
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        return self.mutable_file_node_class(self).init_from_uri(u)

    def upload(self, uploadable):
        """Read all of `uploadable` and return a Deferred firing with a
        LiteralFileURI that embeds the data."""
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda size: uploadable.read(size))
        d.addCallback(lambda chunks: uri.LiteralFileURI("".join(chunks)))
        return d
228
229
def flip_bit(original, byte_offset):
    """Return a copy of the string `original` whose character at
    `byte_offset` has its lowest bit inverted.

    Raises IndexError if byte_offset is out of range.
    """
    before = original[:byte_offset]
    after = original[byte_offset+1:]
    target = ord(original[byte_offset])
    return before + chr(target ^ 0x01) + after
234
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    """Flip one bit in shares held by the FakeStorage `s`; return `res`.

    `res` is passed through untouched, so this can be used directly as a
    Deferred callback.

    `offset` selects the byte to damage in each share. It may be:
      - the string "pubkey" (a fixed offset of 107 is used),
      - a key of the share's offsets table (from unpack_header),
      - a plain integer byte offset, or
      - a tuple (offset1, offset2): offset1 is any of the forms above and
        offset2 is a number of bytes added to it.
    `offset_offset` is a further number of bytes added to the result.

    If shnums_to_corrupt is None, every share on every peer is corrupted.
    Otherwise it is a list of the share numbers to corrupt.
    """
    # the offset spec is the same for every share: parse it once
    if isinstance(offset, tuple):
        offset1, offset2 = offset
    else:
        offset1 = offset
        offset2 = 0
    for shares in s._peers.values():
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            # only the offsets table (the last header field) is needed
            o = unpack_header(data)[-1]
            if offset1 == "pubkey":
                # NOTE(review): 107 is presumably the fixed SDMF header size,
                # with the pubkey starting right after it -- TODO confirm
                # against allmydata.mutable.layout
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            # replacing the value of an existing key while iterating is safe:
            # the dict does not change size
            shares[shnum] = flip_bit(data, real_offset)
    return res
266
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    """Tests of MutableFileNode create/upload/download/modify.

    The tests run against a FakeClient whose shares live in RAM
    (self.client._storage).
    """
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        """create_mutable_file() yields a FastMutableFileNode.

        The first server, in sorted serverid order, ends up holding exactly
        one share.
        """
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = sorted(self.client.get_all_serverids())[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        """_do_serialized(f, *args, **kwargs) calls f with those arguments.

        Its Deferred fires with f's return value. An exception raised by f
        is reported as a failure, which shouldFail checks.
        """
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        """Round-trip contents through overwrite()/upload() and every
        download path.

        The download paths are download_best_version, download() and
        download_version. The test also checks that the servermap dump
        mentions 3-of-10 encoding, and covers a file larger than the
        default readsize.
        """
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            # overwrite() is expected to fire with None
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            # upload() is given a servermap obtained in MODE_WRITE
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        """Initial contents passed to create_mutable_file() are readable,
        and can then be overwritten."""
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        """Contents one byte larger than Publish.MAX_SEGMENT_SIZE are
        rejected with FileTooLargeError.

        This applies both at creation time and when overwriting an existing
        file.
        """
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        """Assert the seqnum of n's best recoverable version.

        The version is taken from a fresh MODE_READ servermap, and its
        seqnum (verinfo[0]) must equal expected_seqnum. `which` is used as
        the assertion message. Returns a Deferred.
        """
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        """n.modify(modifier) with a variety of modifier behaviours.

        The modifiers change the contents, return them unchanged, return
        None, raise, produce oversized contents, or raise
        UncoordinatedWriteError once. After each call the test checks the
        resulting contents and seqnum.
        """
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        # counts modifier invocations (shared by the two _ucw_* modifiers)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            # unchanged contents: seqnum stays at 2
            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            # modifier returning None: seqnum stays at 2
            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            # a raising modifier propagates its exception; nothing published
            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            # oversized new contents are rejected; nothing published
            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            # one simulated UCWE: the modifier is called twice in total
            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        """The backoffer argument of n.modify decides what happens after an
        UncoordinatedWriteError.

        - _backoff_stopper returns the failure, so modify fails.
        - _backoff_pauser waits 0.5s, then the modify is retried.
        - giveuper.delay (a BackoffAgent) retries until it gives up, and
          modify then fails.
        """
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            # hand the failure back: stop retrying
            return f
        def _backoff_pauser(node, f):
            # fire after 0.5s, letting modify try again
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        # NOTE(review): presumably the per-retry delay multiplier; 1 keeps
        # the delay from growing -- confirm in BackoffAgent
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        """Like the first part of test_upload_and_download, but with the
        regular MutableFileNode (full-size keys) instead of
        FastMutableFileNode."""
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d
560
561
class MakeShares(unittest.TestCase):
    """Tests of Publish's encrypt/encode and share-generation steps.

    Both tests drive a Publish object directly, with hand-set parameters.
    """
    def test_encrypt(self):
        """_encrypt_and_encode() with 3-of-10 encoding of 21 bytes of
        contents yields ten 7-byte shares and ten share ids."""
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                # len(CONTENTS) == 21, split among k=3 -> 7 bytes per share
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        """_generate_shares() turns ten fake 7-byte blocks into ten complete
        shares.

        Each share must survive unpack_share() with the expected header
        fields, carry a valid signature, and include the expected hash
        chains and encrypted private key.
        """
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for sh in final_shares.values():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                # a single segment holding all 21 bytes of CONTENTS
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                # check the signature by verifying it with the public key,
                # rather than by comparing it against a freshly made one
                self.failUnless(p._pubkey.verify(sig_material, signature))
                self.failUnless(isinstance(share_hash_chain, dict))
                # presumably ceil(log2(10)) == 4: one hash per level of the
                # share hash tree -- TODO confirm
                self.failUnlessEqual(len(share_hash_chain), 4)
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d
649
    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of those peers; when we publish to 3 peers, we should get either 3
    # or 4 shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError.
653
class PublishMixin:
    """Helpers that publish mutable files onto a FakeClient grid.

    Intended to be mixed into TestCase subclasses. The helpers set
    self._client, self._storage and self._fn, plus self._fn2 (publish_one)
    or self._copied_shares (publish_multiple), and self.CONTENTS.
    """

    def publish_one(self):
        """Publish one file on a 20-peer FakeClient.

        Its shares can then be inspected or mangled by the test. self._fn2
        is a sibling node created from the same URI. Returns a Deferred.
        """
        self.CONTENTS = "New contents go here" * 1000
        self._client = FakeClient(20)
        self._storage = self._client._storage
        def _remember_nodes(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d = self._client.create_mutable_file(self.CONTENTS)
        d.addCallback(_remember_nodes)
        return d

    def publish_multiple(self):
        """Publish several versions of one file, snapshotting each version.

        self._copied_shares[i] ends up holding the shares for
        self.CONTENTS[i]:
          0 -> seqnum 1
          1 -> s2
          2 -> s3
          3 -> s4a
          4 -> s4b
        s4b is written after every share has been rolled back to s3. The
        live storage is left holding version 4. Returns a Deferred.
        """
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        self._client = FakeClient(20)
        self._storage = self._client._storage
        def _make_versions(node):
            self._fn = node
            chain = defer.succeed(None)
            chain.addCallback(self._copy_shares, 0)
            # s2, s3, s4a: overwrite, then snapshot. 'index' is bound as a
            # default so each lambda keeps its own value.
            for index in (1, 2, 3):
                chain.addCallback(lambda ignored, index=index:
                                  node.overwrite(self.CONTENTS[index]))
                chain.addCallback(self._copy_shares, index)
            # Put every share (shnums 0-9) back to version s3, then publish
            # again to get the competing version s4b.
            rollback = dict.fromkeys(range(10), 2)
            chain.addCallback(lambda ignored: self._set_versions(rollback))
            chain.addCallback(lambda ignored: node.overwrite(self.CONTENTS[4]))
            chain.addCallback(self._copy_shares, 4)
            return chain
        d = self._client.create_mutable_file(self.CONTENTS[0])
        d.addCallback(_make_versions)
        return d

    def _copy_shares(self, ignored, index):
        """Store a snapshot of the live shares as self._copied_shares[index].

        Both dict levels (peerid -> {shnum: data}) are copied, so later
        changes to the live storage do not leak into the snapshot.
        """
        live = self._client._storage._peers
        snapshot = {}
        for peerid in live:
            snapshot[peerid] = dict(live[peerid])
        self._copied_shares[index] = snapshot

    def _set_versions(self, versionmap):
        """Overwrite live shares with shares from earlier snapshots.

        versionmap maps shnum -> snapshot index (0-4). Any shnum missing
        from the map keeps its current contents.
        """
        live = self._client._storage._peers
        snapshots = self._copied_shares
        for peerid, peer_shares in live.items():
            for shnum in peer_shares:
                if shnum not in versionmap:
                    continue
                peer_shares[shnum] = snapshots[versionmap[shnum]][peerid][shnum]
723
724
class Servermap(unittest.TestCase, PublishMixin):
    """Exercise ServermapUpdater in each mode.

    The file under test is the one published by PublishMixin.publish_one().
    """

    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        """Map fn (default: self._fn) into a brand-new ServerMap.

        Returns the updater's Deferred.
        """
        if fn is None:
            fn = self._fn
        updater = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        return updater.update()

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        """Run another update pass for self._fn over an existing servermap.

        Returns the updater's Deferred.
        """
        updater = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        return updater.update()

    def failUnlessOneRecoverable(self, sm, num_shares):
        """Assert that sm knows exactly one version and that it is recoverable.

        num_shares of its shares must have been seen, with k=3 and N=10.
        Returns sm so this can sit in a callback chain.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        available = sm.shares_available()
        self.failUnlessEqual(len(available), 1)
        self.failUnlessEqual(available[best], (num_shares, 3, 10))
        # spot-check one (shnum, peer) pair taken from the sharemap
        first_shnum, holders = sm.make_sharemap().items()[0]
        some_peer = list(holders)[0]
        self.failUnlessEqual(sm.version_on_peer(some_peer, first_shnum), best)
        # shnum 666 is far outside N=10, so no version is expected for it
        self.failUnlessEqual(sm.version_on_peer(some_peer, 666), None)
        return sm

    def test_basic(self):
        ms = self.make_servermap
        us = self.update_servermap
        d = defer.succeed(None)

        # Fresh servermaps. MODE_CHECK and MODE_WRITE see all 10 shares.
        # MODE_READ stops at k+epsilon, and epsilon=k, so it sees 6 shares.
        # MODE_ANYTHING stops at k, so it sees 3 shares.
        for (mode, expected) in [(MODE_CHECK, 10),
                                 (MODE_WRITE, 10),
                                 (MODE_READ, 6),
                                 (MODE_ANYTHING, 3)]:
            d.addCallback(lambda res, mode=mode: ms(mode=mode))
            d.addCallback(lambda sm, expected=expected:
                          self.failUnlessOneRecoverable(sm, expected))

        # Can we re-use the same servermap? The modes below are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        for (mode, expected) in [(MODE_READ, 6),
                                 (MODE_WRITE, 10),
                                 (MODE_CHECK, 10),
                                 (MODE_ANYTHING, 10)]:
            d.addCallback(lambda sm, mode=mode: us(sm, mode=mode))
            d.addCallback(lambda sm, expected=expected:
                          self.failUnlessOneRecoverable(sm, expected))

        return d

    def test_fetch_privkey(self):
        # The sibling filenode (self._fn2) hasn't been used yet. Make sure it
        # can fetch the privkey. The file is small, so the privkey will be
        # fetched on the first (query) pass.
        d = defer.succeed(None)
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # Now create a file large enough to knock the privkey out of the
        # early part of the share, and map it through a fresh sibling node.
        LARGE = "These are Larger contents" * 200 # about 5KB
        def _map_sibling(large_fn):
            sibling = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, sibling)
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        d.addCallback(_map_sibling)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        d.addCallback(lambda res: self.make_servermap(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _mark_and_update(sm):
            best = sm.best_recoverable_version()
            found = list(sm.make_versionmap()[best])
            self.failUnlessEqual(len(found), 6)
            # Mark every share with shnum < 5 as corrupt, then update the
            # servermap. The marked shares should no longer be in the map,
            # and new shares should be found to replace them.
            self._corrupted = set()
            for (shnum, peerid, timestamp) in found:
                if shnum >= 5:
                    continue
                self._corrupted.add( (peerid, shnum) )
                sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_mark_and_update)
        def _verify_map(sm):
            # The map should hold the 5 shares that weren't marked bad, and
            # none of the marked ones.
            best = sm.best_recoverable_version()
            found = list(sm.make_versionmap()[best])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(found), 5)
        d.addCallback(_verify_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        """Assert that sm found no versions at all."""
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        self.failUnlessEqual(sm.best_recoverable_version(), None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        d = defer.succeed(None)
        for mode in (MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(lambda res, mode=mode: self.make_servermap(mode=mode))
            d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))
        return d

    def failUnlessNotQuiteEnough(self, sm):
        """Assert that sm saw one version, but only 2 of the 3 shares needed.

        That makes the version unrecoverable. Returns sm.
        """
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        self.failUnlessEqual(sm.best_recoverable_version(), None)
        available = sm.shares_available()
        self.failUnlessEqual(len(available), 1)
        self.failUnlessEqual(available.values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        storage = self._client._storage
        # empty out peer entries until only two are left untouched
        left = len(storage._peers)
        for peerid in storage._peers:
            storage._peers[peerid] = {}
            left -= 1
            if left == 2:
                break
        # now there ought to be only two shares left
        assert len([p for p in storage._peers if storage._peers[p]]) == 2

        d = defer.succeed(None)
        d.addCallback(lambda res: self.make_servermap(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        for mode in (MODE_ANYTHING, MODE_WRITE, MODE_READ):
            d.addCallback(lambda res, mode=mode: self.make_servermap(mode=mode))
            d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        return d
901
902
903
904 class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
905     def setUp(self):
906         return self.publish_one()
907
908     def make_servermap(self, mode=MODE_READ, oldmap=None):
909         if oldmap is None:
910             oldmap = ServerMap()
911         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
912         d = smu.update()
913         return d
914
915     def abbrev_verinfo(self, verinfo):
916         if verinfo is None:
917             return None
918         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
919          offsets_tuple) = verinfo
920         return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])
921
922     def abbrev_verinfo_dict(self, verinfo_d):
923         output = {}
924         for verinfo,value in verinfo_d.items():
925             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
926              offsets_tuple) = verinfo
927             output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
928         return output
929
930     def dump_servermap(self, servermap):
931         print "SERVERMAP", servermap
932         print "RECOVERABLE", [self.abbrev_verinfo(v)
933                               for v in servermap.recoverable_versions()]
934         print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
935         print "available", self.abbrev_verinfo_dict(servermap.shares_available())
936
937     def do_download(self, servermap, version=None):
938         if version is None:
939             version = servermap.best_recoverable_version()
940         r = Retrieve(self._fn, servermap, version)
941         return r.download()
942
943     def test_basic(self):
944         d = self.make_servermap()
945         def _do_retrieve(servermap):
946             self._smap = servermap
947             #self.dump_servermap(servermap)
948             self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
949             return self.do_download(servermap)
950         d.addCallback(_do_retrieve)
951         def _retrieved(new_contents):
952             self.failUnlessEqual(new_contents, self.CONTENTS)
953         d.addCallback(_retrieved)
954         # we should be able to re-use the same servermap, both with and
955         # without updating it.
956         d.addCallback(lambda res: self.do_download(self._smap))
957         d.addCallback(_retrieved)
958         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
959         d.addCallback(lambda res: self.do_download(self._smap))
960         d.addCallback(_retrieved)
961         # clobbering the pubkey should make the servermap updater re-fetch it
962         def _clobber_pubkey(res):
963             self._fn._pubkey = None
964         d.addCallback(_clobber_pubkey)
965         d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
966         d.addCallback(lambda res: self.do_download(self._smap))
967         d.addCallback(_retrieved)
968         return d
969
970     def test_all_shares_vanished(self):
971         d = self.make_servermap()
972         def _remove_shares(servermap):
973             for shares in self._storage._peers.values():
974                 shares.clear()
975             d1 = self.shouldFail(NotEnoughSharesError,
976                                  "test_all_shares_vanished",
977                                  "ran out of peers",
978                                  self.do_download, servermap)
979             return d1
980         d.addCallback(_remove_shares)
981         return d
982
983     def test_no_servers(self):
984         c2 = FakeClient(0)
985         self._fn._client = c2
986         # if there are no servers, then a MODE_READ servermap should come
987         # back empty
988         d = self.make_servermap()
989         def _check_servermap(servermap):
990             self.failUnlessEqual(servermap.best_recoverable_version(), None)
991             self.failIf(servermap.recoverable_versions())
992             self.failIf(servermap.unrecoverable_versions())
993             self.failIf(servermap.all_peers())
994         d.addCallback(_check_servermap)
995         return d
996     test_no_servers.timeout = 15
997
998     def test_no_servers_download(self):
999         c2 = FakeClient(0)
1000         self._fn._client = c2
1001         d = self.shouldFail(UnrecoverableFileError,
1002                             "test_no_servers_download",
1003                             "no recoverable versions",
1004                             self._fn.download_best_version)
1005         def _restore(res):
1006             # a failed download that occurs while we aren't connected to
1007             # anybody should not prevent a subsequent download from working.
1008             # This isn't quite the webapi-driven test that #463 wants, but it
1009             # should be close enough.
1010             self._fn._client = self._client
1011             return self._fn.download_best_version()
1012         def _retrieved(new_contents):
1013             self.failUnlessEqual(new_contents, self.CONTENTS)
1014         d.addCallback(_restore)
1015         d.addCallback(_retrieved)
1016         return d
1017     test_no_servers_download.timeout = 15
1018
1019     def _test_corrupt_all(self, offset, substring,
1020                           should_succeed=False, corrupt_early=True,
1021                           failure_checker=None):
1022         d = defer.succeed(None)
1023         if corrupt_early:
1024             d.addCallback(corrupt, self._storage, offset)
1025         d.addCallback(lambda res: self.make_servermap())
1026         if not corrupt_early:
1027             d.addCallback(corrupt, self._storage, offset)
1028         def _do_retrieve(servermap):
1029             ver = servermap.best_recoverable_version()
1030             if ver is None and not should_succeed:
1031                 # no recoverable versions == not succeeding. The problem
1032                 # should be noted in the servermap's list of problems.
1033                 if substring:
1034                     allproblems = [str(f) for f in servermap.problems]
1035                     self.failUnless(substring in "".join(allproblems))
1036                 return servermap
1037             if should_succeed:
1038                 d1 = self._fn.download_version(servermap, ver)
1039                 d1.addCallback(lambda new_contents:
1040                                self.failUnlessEqual(new_contents, self.CONTENTS))
1041             else:
1042                 d1 = self.shouldFail(NotEnoughSharesError,
1043                                      "_corrupt_all(offset=%s)" % (offset,),
1044                                      substring,
1045                                      self._fn.download_version, servermap, ver)
1046             if failure_checker:
1047                 d1.addCallback(failure_checker)
1048             d1.addCallback(lambda res: servermap)
1049             return d1
1050         d.addCallback(_do_retrieve)
1051         return d
1052
1053     def test_corrupt_all_verbyte(self):
1054         # when the version byte is not 0, we hit an assertion error in
1055         # unpack_share().
1056         d = self._test_corrupt_all(0, "AssertionError")
1057         def _check_servermap(servermap):
1058             # and the dump should mention the problems
1059             s = StringIO()
1060             dump = servermap.dump(s).getvalue()
1061             self.failUnless("10 PROBLEMS" in dump, dump)
1062         d.addCallback(_check_servermap)
1063         return d
1064
1065     def test_corrupt_all_seqnum(self):
1066         # a corrupt sequence number will trigger a bad signature
1067         return self._test_corrupt_all(1, "signature is invalid")
1068
1069     def test_corrupt_all_R(self):
1070         # a corrupt root hash will trigger a bad signature
1071         return self._test_corrupt_all(9, "signature is invalid")
1072
1073     def test_corrupt_all_IV(self):
1074         # a corrupt salt/IV will trigger a bad signature
1075         return self._test_corrupt_all(41, "signature is invalid")
1076
1077     def test_corrupt_all_k(self):
1078         # a corrupt 'k' will trigger a bad signature
1079         return self._test_corrupt_all(57, "signature is invalid")
1080
1081     def test_corrupt_all_N(self):
1082         # a corrupt 'N' will trigger a bad signature
1083         return self._test_corrupt_all(58, "signature is invalid")
1084
1085     def test_corrupt_all_segsize(self):
1086         # a corrupt segsize will trigger a bad signature
1087         return self._test_corrupt_all(59, "signature is invalid")
1088
1089     def test_corrupt_all_datalen(self):
1090         # a corrupt data length will trigger a bad signature
1091         return self._test_corrupt_all(67, "signature is invalid")
1092
1093     def test_corrupt_all_pubkey(self):
1094         # a corrupt pubkey won't match the URI's fingerprint. We need to
1095         # remove the pubkey from the filenode, or else it won't bother trying
1096         # to update it.
1097         self._fn._pubkey = None
1098         return self._test_corrupt_all("pubkey",
1099                                       "pubkey doesn't match fingerprint")
1100
1101     def test_corrupt_all_sig(self):
1102         # a corrupt signature is a bad one
1103         # the signature runs from about [543:799], depending upon the length
1104         # of the pubkey
1105         return self._test_corrupt_all("signature", "signature is invalid")
1106
1107     def test_corrupt_all_share_hash_chain_number(self):
1108         # a corrupt share hash chain entry will show up as a bad hash. If we
1109         # mangle the first byte, that will look like a bad hash number,
1110         # causing an IndexError
1111         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1112
1113     def test_corrupt_all_share_hash_chain_hash(self):
1114         # a corrupt share hash chain entry will show up as a bad hash. If we
1115         # mangle a few bytes in, that will look like a bad hash.
1116         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1117
1118     def test_corrupt_all_block_hash_tree(self):
1119         return self._test_corrupt_all("block_hash_tree",
1120                                       "block hash tree failure")
1121
1122     def test_corrupt_all_block(self):
1123         return self._test_corrupt_all("share_data", "block hash tree failure")
1124
1125     def test_corrupt_all_encprivkey(self):
1126         # a corrupted privkey won't even be noticed by the reader, only by a
1127         # writer.
1128         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1129
1130
1131     def test_corrupt_all_seqnum_late(self):
1132         # corrupting the seqnum between mapupdate and retrieve should result
1133         # in NotEnoughSharesError, since each share will look invalid
1134         def _check(res):
1135             f = res[0]
1136             self.failUnless(f.check(NotEnoughSharesError))
1137             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1138         return self._test_corrupt_all(1, "ran out of peers",
1139                                       corrupt_early=False,
1140                                       failure_checker=_check)
1141
1142     def test_corrupt_all_block_hash_tree_late(self):
1143         def _check(res):
1144             f = res[0]
1145             self.failUnless(f.check(NotEnoughSharesError))
1146         return self._test_corrupt_all("block_hash_tree",
1147                                       "block hash tree failure",
1148                                       corrupt_early=False,
1149                                       failure_checker=_check)
1150
1151
1152     def test_corrupt_all_block_late(self):
1153         def _check(res):
1154             f = res[0]
1155             self.failUnless(f.check(NotEnoughSharesError))
1156         return self._test_corrupt_all("share_data", "block hash tree failure",
1157                                       corrupt_early=False,
1158                                       failure_checker=_check)
1159
1160
1161     def test_basic_pubkey_at_end(self):
1162         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1163         # download to succeed but forcing a bunch of retries first. Note that
1164         # this is rather pessimistic: our Retrieve process will throw away
1165         # the whole share if the pubkey is bad, even though the rest of the
1166         # share might be good.
1167
1168         self._fn._pubkey = None
1169         k = self._fn.get_required_shares()
1170         N = self._fn.get_total_shares()
1171         d = defer.succeed(None)
1172         d.addCallback(corrupt, self._storage, "pubkey",
1173                       shnums_to_corrupt=range(0, N-k))
1174         d.addCallback(lambda res: self.make_servermap())
1175         def _do_retrieve(servermap):
1176             self.failUnless(servermap.problems)
1177             self.failUnless("pubkey doesn't match fingerprint"
1178                             in str(servermap.problems[0]))
1179             ver = servermap.best_recoverable_version()
1180             r = Retrieve(self._fn, servermap, ver)
1181             return r.download()
1182         d.addCallback(_do_retrieve)
1183         d.addCallback(lambda new_contents:
1184                       self.failUnlessEqual(new_contents, self.CONTENTS))
1185         return d
1186
1187     def test_corrupt_some(self):
1188         # corrupt the data of first five shares (so the servermap thinks
1189         # they're good but retrieve marks them as bad), so that the
1190         # MODE_READ set of 6 will be insufficient, forcing node.download to
1191         # retry with more servers.
1192         corrupt(None, self._storage, "share_data", range(5))
1193         d = self.make_servermap()
1194         def _do_retrieve(servermap):
1195             ver = servermap.best_recoverable_version()
1196             self.failUnless(ver)
1197             return self._fn.download_best_version()
1198         d.addCallback(_do_retrieve)
1199         d.addCallback(lambda new_contents:
1200                       self.failUnlessEqual(new_contents, self.CONTENTS))
1201         return d
1202
1203     def test_download_fails(self):
1204         corrupt(None, self._storage, "signature")
1205         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1206                             "no recoverable versions",
1207                             self._fn.download_best_version)
1208         return d
1209
1210
class CheckerMixin:
    """Assertion helpers for check/verify results.

    Each helper takes the results as its first argument, so it can be used
    directly as a Deferred callback.
    """

    def check_good(self, r, where):
        """Assert that r reports a healthy file; returns r unchanged."""
        healthy = r.is_healthy()
        self.failUnless(healthy, where)
        return r

    def check_bad(self, r, where):
        """Assert that r reports an unhealthy file; returns r unchanged."""
        healthy = r.is_healthy()
        self.failIf(healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        """Assert that r.problems contains a failure of type expected_exception.

        Only the first problem whose failure matches expected_exception is
        examined, and its text must contain substring. r.problems is a list
        of (peerid, storage_index, shnum, failure) tuples.
        """
        for (_peerid, _storage_index, _shnum, problem) in r.problems:
            if problem.check(expected_exception):
                break
        else:
            self.fail("%s: didn't see expected exception %s in problems %s" %
                      (where, expected_exception, r.problems))
            return
        description = str(problem)
        self.failUnless(substring in description,
                        "%s: substring '%s' not in '%s'" %
                        (where, substring, description))
1229
1230
class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    """Tests for MutableFileNode.check(), with and without verify=True.

    A plain check does not look at the actual share data; verify=True
    examines every byte.
    """

    def setUp(self):
        return self.publish_one()

    def _run_check(self, checker, where, verify=False):
        """Start a check of self._fn and chain checker(results, where).

        verify=True is passed to check() only when requested. Returns the
        Deferred.
        """
        if verify:
            d = self._fn.check(Monitor(), verify=True)
        else:
            d = self._fn.check(Monitor())
        d.addCallback(checker, where)
        return d

    def test_check_good(self):
        return self._run_check(self.check_good, "test_check_good")

    def test_check_no_shares(self):
        for peer_shares in self._storage._peers.values():
            peer_shares.clear()
        return self._run_check(self.check_bad, "test_check_no_shares")

    def test_check_not_enough_shares(self):
        # drop every share except shnum 0 from every server
        for peer_shares in self._storage._peers.values():
            doomed = [shnum for shnum in peer_shares.keys() if shnum > 0]
            for shnum in doomed:
                del peer_shares[shnum]
        return self._run_check(self.check_bad, "test_check_not_enough_shares")

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        return self._run_check(self.check_bad, "test_check_all_bad_sig")

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this.. it doesn't look at actual data
        return self._run_check(self.check_good, "test_check_all_bad_blocks")

    def test_verify_good(self):
        return self._run_check(self.check_good, "test_verify_good",
                               verify=True)

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        return self._run_check(self.check_bad, "test_verify_all_bad_sig",
                               verify=True)

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        return self._run_check(self.check_bad, "test_verify_one_bad_sig",
                               verify=True)

    def test_verify_one_bad_block(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        d = self._run_check(self.check_bad, "test_verify_one_bad_block",
                            verify=True)
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "block hash tree failure",
                      "test_verify_one_bad_block")
        return d

    def test_verify_one_bad_sharehash(self):
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        d = self._run_check(self.check_bad, "test_verify_one_bad_sharehash",
                            verify=True)
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "corrupt hashes",
                      "test_verify_one_bad_sharehash")
        return d

    def test_verify_one_bad_encprivkey(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        d = self._run_check(self.check_bad, "test_verify_one_bad_encprivkey",
                            verify=True)
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "invalid privkey",
                      "test_verify_one_bad_encprivkey")
        return d

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        # a read-only node has no way to validate the privkey
        readonly_fn = self._fn.get_readonly()
        d = readonly_fn.check(Monitor(), verify=True)
        d.addCallback(self.check_good,
                      "test_verify_one_bad_encprivkey_uncheckable")
        return d
1323
class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
    """Exercise MutableFileNode.repair() on a healthy file and on grids that
    hold several versions of the same file."""

    def get_shares(self, s):
        """Snapshot every share held by the fake storage object 's'.

        Returns a dict mapping (peerid, shnum) to the raw share data."""
        all_shares = {}
        for peerid, shares in s._peers.items():
            for shnum, data in shares.items():
                all_shares[(peerid, shnum)] = data
        return all_shares

    def copy_shares(self, ignored=None):
        """Append a snapshot of the current shares to self.old_shares.

        The ignored argument lets this be used directly as a Deferred
        callback; it returns None, so the next callback receives None."""
        self.old_shares.append(self.get_shares(self._storage))

    def test_repair_nop(self):
        self.old_shares = []
        d = self.publish_one()
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        def _check_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine results

            self.copy_shares()

            initial_shares = self.old_shares[0]
            new_shares = self.old_shares[1]
            # TODO: this really shouldn't change anything. When we implement
            # a "minimal-bandwidth" repairer, change this test to assert:
            #self.failUnlessEqual(new_shares, initial_shares)

            # all shares should be in the same place as before
            self.failUnlessEqual(set(initial_shares.keys()),
                                 set(new_shares.keys()))
            # but they should all be at a newer seqnum. The IV will be
            # different, so the roothash will be too.
            for key in initial_shares:
                (version0,
                 seqnum0,
                 root_hash0,
                 IV0,
                 k0, N0, segsize0, datalen0,
                 o0) = unpack_header(initial_shares[key])
                (version1,
                 seqnum1,
                 root_hash1,
                 IV1,
                 k1, N1, segsize1, datalen1,
                 o1) = unpack_header(new_shares[key])
                self.failUnlessEqual(version0, version1)
                self.failUnlessEqual(seqnum0+1, seqnum1)
                self.failUnlessEqual(k0, k1)
                self.failUnlessEqual(N0, N1)
                self.failUnlessEqual(segsize0, segsize1)
                self.failUnlessEqual(datalen0, datalen1)
        d.addCallback(_check_results)
        return d

    def failIfSharesChanged(self, ignored=None):
        """Assert that the two most recent share snapshots are identical."""
        old_shares = self.old_shares[-2]
        current_shares = self.old_shares[-1]
        self.failUnlessEqual(old_shares, current_shares)

    def test_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair will refuse to merge multiple highest seqnums unless you
        # pass force=True
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        def _try_repair(check_results):
            ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
            d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
                                 self._fn.repair, check_results)
            # the refused repair must not have modified any share
            d2.addCallback(self.copy_shares)
            d2.addCallback(self.failIfSharesChanged)
            d2.addCallback(lambda res: check_results)
            return d2
        d.addCallback(_try_repair)
        d.addCallback(lambda check_results:
                      self._fn.repair(check_results, force=True))
        # this should give us 10 shares of the highest roothash
        def _check_repair_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine the rest of the results
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # now, which should have won? The higher roothash.
            roothash_s4a = self.get_roothash_for(3)
            roothash_s4b = self.get_roothash_for(4)
            if roothash_s4b > roothash_s4a:
                expected_contents = self.CONTENTS[4]
            else:
                expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def test_non_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair should not refuse a repair that doesn't need to merge. In
        # this case, we combine v2 with v3. The repair should ignore v2 and
        # copy v3 into a new v5.
        d.addCallback(lambda res:
                      self._set_versions({0:2,2:2,4:2,6:2,8:2,
                                          1:3,3:3,5:3,7:3,9:3}))
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        # this should give us 10 shares of v3
        def _check_repair_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine the rest of the results
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # v3 is the newest version present, so its contents should win
            expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def get_roothash_for(self, index):
        """Return the root hash of the first share found in the saved copy
        of published version 'index'.

        self._copied_shares is not populated by this class (presumably by
        PublishMixin.publish_multiple). Returns None if that copy holds no
        shares."""
        shares = self._copied_shares[index]
        for peerid in shares:
            for shnum in shares[peerid]:
                share = shares[peerid][shnum]
                (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
                          unpack_header(share)
                return root_hash
1471
class MultipleEncodings(unittest.TestCase):
    """Check that download copes with a grid holding the same file encoded
    with several different (k, N) parameters at once."""

    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        """Publish 'data' as a k-of-n encoding of self._fn's file.

        Returns a Deferred that fires with the resulting peerid->shares dict.
        The fake storage is emptied both before and after the publish, so the
        new shares exist only in that returned dict."""

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we publish each encoding in turn (via _encode), and grab its
        # shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set s1, s2 or s3
            # according to the following sequence:
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # so that neither form can be recovered until fetch [g], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will not be recoverable until fetch [h]. s3 only ever
            # gets one share, so it is never recoverable.

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
            sources = {1: self._shares1, 2: self._shares2, 3: self._shares3}
            storage = self._client._storage

            sharemap = {}

            for peerid in self._client.get_all_serverids():
                shnums = self._shares1.get(peerid, {})
                if not shnums:
                    continue
                # one fresh dict per peer: (re)creating it inside the shnum
                # loop would drop all but the last share placed on a peer
                # that holds several
                peers = {}
                storage._peers[peerid] = peers
                for shnum in shnums:
                    if shnum >= len(places):
                        continue # shares beyond 'places' are not placed
                    source = sources[places[shnum]].get(peerid, {})
                    if shnum in source:
                        peers[shnum] = source[shnum]
                        sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
1613
1614
class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
    """Download, check and modify behaviour when the grid holds shares from
    several published versions of one file (installed via _set_versions)."""

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        evens = (0, 2, 4, 6, 8)
        odds = (1, 3, 5, 7, 9)

        def _build(*groups):
            # each group is a (shnums, version_index) pair
            versions = {}
            for shnums, which in groups:
                for shnum in shnums:
                    versions[shnum] = which
            return versions
        def _apply(*groups):
            # callback factory: installs a freshly built version map
            def _cb(ignored):
                return self._set_versions(_build(*groups))
            return _cb
        def _download(ignored):
            return self._fn.download_best_version()
        def _read_servermap(ignored):
            return self._fn.get_servermap(MODE_READ)
        def _expect(index):
            # compare against self.CONTENTS[index], looked up when it runs
            def _compare(contents):
                self.failUnlessEqual(contents, self.CONTENTS[index])
            return _compare
        def _run_check(ignored):
            return self._fn.check(Monitor())

        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(_build((evens, 2)))
        d = self._fn.download_best_version()
        d.addCallback(_expect(4))
        # and the checker should report problems
        d.addCallback(_run_check)
        d.addCallback(self.check_bad, "test_multiple_versions")

        # but if everything is at version 2, that's what we should download
        d.addCallback(_apply((range(10), 2)))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # if exactly one share is at version 3, we should still get v2
        d.addCallback(_apply(((0,), 3)))
        d.addCallback(_download)
        d.addCallback(_expect(2))
        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(_read_servermap)
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = list(newer.items())[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        d.addCallback(_apply((evens, 3), (odds, 4)))
        d.addCallback(_read_servermap)
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(_download)
        def _got_either(contents):
            self.failUnless(contents in (self.CONTENTS[3], self.CONTENTS[4]))
        d.addCallback(_got_either)
        return d

    def test_replace(self):
        # if we see a mix of versions in the grid, we should be able to
        # replace them all with a newer version.
        #
        # With exactly one share at index 3 (seqnum 4) and the rest at index
        # 2 (seqnum 3), modify() should download (and replace) v2, and the
        # result should be seqnum 5. Note that the index we give to
        # _set_versions is different than the sequence number.
        target = {}
        for shnum in range(10):
            target[shnum] = 2 # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        suffix = " modified"
        def _append_suffix(oldversion, servermap, first_time):
            return oldversion + suffix
        d = self._fn.modify(_append_suffix)
        def _download(ignored):
            return self._fn.download_best_version()
        d.addCallback(_download)
        def _compare(contents):
            self.failUnlessEqual(contents, self.CONTENTS[2] + suffix)
        d.addCallback(_compare)
        # and the servermap should indicate that the outlier was replaced too
        def _read_servermap(ignored):
            return self._fn.get_servermap(MODE_CHECK)
        d.addCallback(_read_servermap)
        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
        d.addCallback(_check_smap)
        return d
1694
1695
class Utils(unittest.TestCase):
    """Tests for ResponseCache's range helpers (_inside, _does_overlap) and
    for its add()/read() behaviour."""

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # the reference answer models both ranges as sets of integer offsets
        inner = set(range(x_start, x_start + x_length))
        outer = set(range(y_start, y_start + y_length))
        expected = inner.issubset(outer)
        actual = c._inside(x_start, x_length, y_start, y_length)
        self.failUnlessEqual(expected, actual,
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_inside(self):
        cache = ResponseCache()
        # the fixed span starts at 10 with length 5; y spans start at 8..16
        # and have lengths 0..7
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(cache, 10, 5, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # the reference answer models both ranges as sets of integer offsets
        first = set(range(x_start, x_start + x_length))
        second = set(range(y_start, y_start + y_length))
        expected = len(first & second) > 0
        actual = c._does_overlap(x_start, x_length, y_start, y_length)
        self.failUnlessEqual(expected, actual,
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_overlap(self):
        cache = ResponseCache()
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(cache, 10, 5, y_start, y_length)

    def test_cache(self):
        cache = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        cache.add("v1", 1, 0, xdata, "time0")
        cache.add("v1", 1, 2000, ydata, "time1")
        # (read() arguments, expected result), checked in this order
        cases = [
            (("v2", 1, 10, 11), nope),
            (("v1", 2, 10, 11), nope),
            (("v1", 1, 0, 10), (xdata[:10], "time0")),
            (("v1", 1, 90, 10), (xdata[90:], "time0")),
            (("v1", 1, 300, 10), nope),
            (("v1", 1, 2050, 5), (ydata[50:55], "time1")),
            (("v1", 1, 0, 101), nope),
            (("v1", 1, 99, 1), (xdata[99:100], "time0")),
            (("v1", 1, 100, 1), nope),
            ]
        # reads that start in the gap below offset 2000 miss, even when they
        # run on into the cached ydata span
        for length in (9, 10, 11, 15, 19, 20, 21, 25):
            cases.append((("v1", 1, 1990, length), nope))
        cases.append((("v1", 1, 1999, 25), nope))
        for args, expected in cases:
            self.failUnlessEqual(cache.read(*args), expected)

        # optional: join adjacent fragments (not asserted yet)
        cache = ResponseCache()
        cache.add("v1", 1, 0, xdata[:10], "time0")
        cache.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(cache.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1763
class Exceptions(unittest.TestCase):
    """repr() of the mutable-file exceptions must mention the class name."""

    def _check_repr_names(self, exc, name):
        text = repr(exc)
        self.failUnless(name in text, text)

    def test_repr(self):
        self._check_repr_names(NeedMoreDataError(100, 50, 100),
                               "NeedMoreDataError")
        self._check_repr_names(UncoordinatedWriteError(),
                               "UncoordinatedWriteError")
1770
# The Problems tests below can't use a plain FakeClient, since it uses
# FakeStorageServer instances which always succeed. So we need a less-fake one.
1773
class IntentionalError(Exception):
    """Raised by a LocalWrapper whose 'broken' flag is set, to simulate a
    failing storage server."""
1776
class LocalWrapper:
    """Make a local server object look like a remote reference.

    callRemote("foo", ...) invokes original.remote_foo(...) once the Deferred
    from foolscap's fireEventually() fires, and returns a Deferred for the
    result. While 'broken' is true, calls fail with IntentionalError instead.
    If 'post_call_notifier' is set, it is added as a callback and receives
    (result, methname).
    """
    def __init__(self, original):
        self.original = original
        self.broken = False
        self.post_call_notifier = None

    def callRemote(self, methname, *args, **kwargs):
        def _invoke(ignored):
            # 'broken' is consulted when the call actually runs, not when
            # it was queued
            if self.broken:
                raise IntentionalError("I was asked to break")
            target = getattr(self.original, "remote_" + methname)
            return target(*args, **kwargs)
        d = fireEventually()
        d.addCallback(_invoke)
        notifier = self.post_call_notifier
        if notifier:
            d.addCallback(notifier, methname)
        return d
1793
class LessFakeClient(FakeClient):
    """A client whose peers are real StorageServer instances, each stored
    under basedir and wrapped in a LocalWrapper so that individual servers
    can be broken. Note that FakeClient.__init__ is not called."""

    def __init__(self, basedir, num_peers=10):
        self._num_peers = num_peers
        self.storage_broker = StorageFarmBroker()
        for i in range(num_peers):
            # deterministic peerids: every instance gets the same sequence
            peerid = tagged_hash("peerid", "%d" % i)[:20]
            peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
            make_dirs(peerdir)
            server = StorageServer(peerdir, peerid)
            self.storage_broker.add_server(peerid, LocalWrapper(server))
        self.nodeid = "fakenodeid"
1808
1809
1810 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1811     def test_publish_surprise(self):
1812         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1813         self.client = LessFakeClient(basedir)
1814         d = self.client.create_mutable_file("contents 1")
1815         def _created(n):
1816             d = defer.succeed(None)
1817             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1818             def _got_smap1(smap):
1819                 # stash the old state of the file
1820                 self.old_map = smap
1821             d.addCallback(_got_smap1)
1822             # then modify the file, leaving the old map untouched
1823             d.addCallback(lambda res: log.msg("starting winning write"))
1824             d.addCallback(lambda res: n.overwrite("contents 2"))
1825             # now attempt to modify the file with the old servermap. This
1826             # will look just like an uncoordinated write, in which every
1827             # single share got updated between our mapupdate and our publish
1828             d.addCallback(lambda res: log.msg("starting doomed write"))
1829             d.addCallback(lambda res:
1830                           self.shouldFail(UncoordinatedWriteError,
1831                                           "test_publish_surprise", None,
1832                                           n.upload,
1833                                           "contents 2a", self.old_map))
1834             return d
1835         d.addCallback(_created)
1836         return d
1837
1838     def test_retrieve_surprise(self):
1839         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1840         self.client = LessFakeClient(basedir)
1841         d = self.client.create_mutable_file("contents 1")
1842         def _created(n):
1843             d = defer.succeed(None)
1844             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1845             def _got_smap1(smap):
1846                 # stash the old state of the file
1847                 self.old_map = smap
1848             d.addCallback(_got_smap1)
1849             # then modify the file, leaving the old map untouched
1850             d.addCallback(lambda res: log.msg("starting winning write"))
1851             d.addCallback(lambda res: n.overwrite("contents 2"))
1852             # now attempt to retrieve the old version with the old servermap.
1853             # This will look like someone has changed the file since we
1854             # updated the servermap.
1855             d.addCallback(lambda res: n._cache._clear())
1856             d.addCallback(lambda res: log.msg("starting doomed read"))
1857             d.addCallback(lambda res:
1858                           self.shouldFail(NotEnoughSharesError,
1859                                           "test_retrieve_surprise",
1860                                           "ran out of peers: have 0 shares (k=3)",
1861                                           n.download_version,
1862                                           self.old_map,
1863                                           self.old_map.best_recoverable_version(),
1864                                           ))
1865             return d
1866         d.addCallback(_created)
1867         return d
1868
1869     def test_unexpected_shares(self):
1870         # upload the file, take a servermap, shut down one of the servers,
1871         # upload it again (causing shares to appear on a new server), then
1872         # upload using the old servermap. The last upload should fail with an
1873         # UncoordinatedWriteError, because of the shares that didn't appear
1874         # in the servermap.
1875         basedir = os.path.join("mutable/CollidingWrites/test_unexpexted_shares")
1876         self.client = LessFakeClient(basedir)
1877         d = self.client.create_mutable_file("contents 1")
1878         def _created(n):
1879             d = defer.succeed(None)
1880             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1881             def _got_smap1(smap):
1882                 # stash the old state of the file
1883                 self.old_map = smap
1884                 # now shut down one of the servers
1885                 peer0 = list(smap.make_sharemap()[0])[0]
1886                 self.client.debug_remove_connection(peer0)
1887                 # then modify the file, leaving the old map untouched
1888                 log.msg("starting winning write")
1889                 return n.overwrite("contents 2")
1890             d.addCallback(_got_smap1)
1891             # now attempt to modify the file with the old servermap. This
1892             # will look just like an uncoordinated write, in which every
1893             # single share got updated between our mapupdate and our publish
1894             d.addCallback(lambda res: log.msg("starting doomed write"))
1895             d.addCallback(lambda res:
1896                           self.shouldFail(UncoordinatedWriteError,
1897                                           "test_surprise", None,
1898                                           n.upload,
1899                                           "contents 2a", self.old_map))
1900             return d
1901         d.addCallback(_created)
1902         return d
1903
    def test_bad_server(self):
        """Publishing and then overwriting must still succeed when, first,
        the peer at the head of the storage index's server list and, later,
        the second peer in that list are broken (LocalWrapper.broken)."""
        # Break one server, then create the file: the initial publish should
        # complete with an alternate server. Breaking a second server should
        # not prevent an update from succeeding either.
        basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
        self.client = LessFakeClient(basedir, 20)
        # to make sure that one of the initial peers is broken, we have to
        # get creative. We create the keys, so we can figure out the storage
        # index, but we hold off on doing the initial publish until we've
        # broken the server on which the first share wants to be stored.
        n = FastMutableFileNode(self.client)
        d = defer.succeed(None)
        d.addCallback(n._generate_pubprivkeys)
        d.addCallback(n._generated)
        def _break_peer0(res):
            si = n.get_storage_index()
            # get_servers(si) yields (peerid, connection) pairs; the order is
            # assumed to match share placement order -- TODO confirm
            peerlist = list(self.client.storage_broker.get_servers(si))
            peerid0, connection0 = peerlist[0]
            peerid1, connection1 = peerlist[1]
            # connections are LocalWrappers: 'broken' makes every later
            # callRemote on this peer fail with IntentionalError
            connection0.broken = True
            # remembered so _break_peer1 can break it after the first publish
            self.connection1 = connection1
        d.addCallback(_break_peer0)
        # now let the initial publish finally happen
        d.addCallback(lambda res: n._upload("contents 1", None))
        # that ought to work
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        # now break the second peer
        def _break_peer1(res):
            self.connection1.broken = True
        d.addCallback(_break_peer1)
        d.addCallback(lambda res: n.overwrite("contents 2"))
        # that ought to work too
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        def _explain_error(f):
            # print diagnostics, then return the Failure so the test still
            # fails with the original error
            print f
            if f.check(NotEnoughServersError):
                print "first_error:", f.value.first_error
            return f
        d.addErrback(_explain_error)
        return d
1946
1947     def test_bad_server_overlap(self):
1948         # like test_bad_server, but with no extra unused servers to fall back
1949         # upon. This means that we must re-use a server which we've already
1950         # used. If we don't remember the fact that we sent them one share
1951         # already, we'll mistakenly think we're experiencing an
1952         # UncoordinatedWriteError.
1953
1954         # Break one server, then create the file: the initial publish should
1955         # complete with an alternate server. Breaking a second server should
1956         # not prevent an update from succeeding either.
1957         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1958         self.client = LessFakeClient(basedir, 10)
1959
1960         peerids = list(self.client.get_all_serverids())
1961         self.client.debug_break_connection(peerids[0])
1962
1963         d = self.client.create_mutable_file("contents 1")
1964         def _created(n):
1965             d = n.download_best_version()
1966             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1967             # now break one of the remaining servers
1968             def _break_second_server(res):
1969                 self.client.debug_break_connection(peerids[1])
1970             d.addCallback(_break_second_server)
1971             d.addCallback(lambda res: n.overwrite("contents 2"))
1972             # that ought to work too
1973             d.addCallback(lambda res: n.download_best_version())
1974             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1975             return d
1976         d.addCallback(_created)
1977         return d
1978
1979     def test_publish_all_servers_bad(self):
1980         # Break all servers: the publish should fail
1981         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1982         self.client = LessFakeClient(basedir, 20)
1983         for peerid in self.client.get_all_serverids():
1984             self.client.debug_break_connection(peerid)
1985         d = self.shouldFail(NotEnoughServersError,
1986                             "test_publish_all_servers_bad",
1987                             "Ran out of non-bad servers",
1988                             self.client.create_mutable_file, "contents")
1989         return d
1990
1991     def test_publish_no_servers(self):
1992         # no servers at all: the publish should fail
1993         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1994         self.client = LessFakeClient(basedir, 0)
1995         d = self.shouldFail(NotEnoughServersError,
1996                             "test_publish_no_servers",
1997                             "Ran out of non-bad servers",
1998                             self.client.create_mutable_file, "contents")
1999         return d
2000     test_publish_no_servers.timeout = 30
2001
2002
2003     def test_privkey_query_error(self):
2004         # when a servermap is updated with MODE_WRITE, it tries to get the
2005         # privkey. Something might go wrong during this query attempt.
2006         self.client = FakeClient(20)
2007         # we need some contents that are large enough to push the privkey out
2008         # of the early part of the file
2009         LARGE = "These are Larger contents" * 200 # about 5KB
2010         d = self.client.create_mutable_file(LARGE)
2011         def _created(n):
2012             self.uri = n.get_uri()
2013             self.n2 = self.client.create_node_from_uri(self.uri)
2014             # we start by doing a map update to figure out which is the first
2015             # server.
2016             return n.get_servermap(MODE_WRITE)
2017         d.addCallback(_created)
2018         d.addCallback(lambda res: fireEventually(res))
2019         def _got_smap1(smap):
2020             peer0 = list(smap.make_sharemap()[0])[0]
2021             # we tell the server to respond to this peer first, so that it
2022             # will be asked for the privkey first
2023             self.client._storage._sequence = [peer0]
2024             # now we make the peer fail their second query
2025             self.client._storage._special_answers[peer0] = ["normal", "fail"]
2026         d.addCallback(_got_smap1)
2027         # now we update a servermap from a new node (which doesn't have the
2028         # privkey yet, forcing it to use a separate privkey query). Each
2029         # query response will trigger a privkey query, and since we're using
2030         # _sequence to make the peer0 response come back first, we'll send it
2031         # a privkey query first, and _sequence will again ensure that the
2032         # peer0 query will also come back before the others, and then
2033         # _special_answers will make sure that the query raises an exception.
2034         # The whole point of these hijinks is to exercise the code in
2035         # _privkey_query_failed. Note that the map-update will succeed, since
2036         # we'll just get a copy from one of the other shares.
2037         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2038         # Using FakeStorage._sequence means there will be read requests still
2039         # floating around.. wait for them to retire
2040         def _cancel_timer(res):
2041             if self.client._storage._pending_timer:
2042                 self.client._storage._pending_timer.cancel()
2043             return res
2044         d.addBoth(_cancel_timer)
2045         return d
2046
2047     def test_privkey_query_missing(self):
2048         # like test_privkey_query_error, but the shares are deleted by the
2049         # second query, instead of raising an exception.
2050         self.client = FakeClient(20)
2051         LARGE = "These are Larger contents" * 200 # about 5KB
2052         d = self.client.create_mutable_file(LARGE)
2053         def _created(n):
2054             self.uri = n.get_uri()
2055             self.n2 = self.client.create_node_from_uri(self.uri)
2056             return n.get_servermap(MODE_WRITE)
2057         d.addCallback(_created)
2058         d.addCallback(lambda res: fireEventually(res))
2059         def _got_smap1(smap):
2060             peer0 = list(smap.make_sharemap()[0])[0]
2061             self.client._storage._sequence = [peer0]
2062             self.client._storage._special_answers[peer0] = ["normal", "none"]
2063         d.addCallback(_got_smap1)
2064         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2065         def _cancel_timer(res):
2066             if self.client._storage._pending_timer:
2067                 self.client._storage._pending_timer.cancel()
2068             return res
2069         d.addBoth(_cancel_timer)
2070         return d