
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri
from allmydata.storage.server import StorageServer
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     NotEnoughSharesError, IRepairResults, ICheckAndRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.api import eventually, fireEventually
from foolscap.logging import log
from allmydata.storage_client import StorageFarmBroker

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# FakeStorage's "fail" answer mode (below) raises this; it is not imported
# from anywhere, so define it here.
class IntentionalError(Exception):
    pass

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
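        # _special_answers maps a peerid to a list of modes ("fail", "none",
        # "normal"); each read() for that peer pops one mode and uses it to
        # fake an error, an empty answer, or a normal response.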
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


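# "FakeStorageServer" wraps a shared FakeStorage instance behind the same
# callRemote()-style interface that the mutable-file code uses to talk to
# real storage servers, implementing just the methods these tests exercise.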
class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        peerids = [tagged_hash("peerid", "%d" % i)[:20]
                   for i in range(self._num_peers)]
        self.nodeid = "fakenodeid"
        self.storage_broker = StorageFarmBroker(None, True)
        for peerid in peerids:
            fss = FakeStorageServer(peerid, self._storage)
            self.storage_broker.test_add_server(peerid, fss)

    def get_storage_broker(self):
        return self.storage_broker
    def debug_break_connection(self, peerid):
        self.storage_broker.test_servers[peerid].broken = True
    def debug_remove_connection(self, peerid):
        self.storage_broker.test_servers.pop(peerid)
    def debug_get_connection(self, peerid):
        return self.storage_broker.test_servers[peerid]

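    # the tests in this module assume these 3-of-10 encoding parameters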
    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u, readcap=None):
        if not u:
            u = readcap
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
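    # e.g. corrupt(None, storage, "signature") flips one bit in each share's
    # signature, and corrupt(None, storage, ("share_hash_chain", 4)) flips a
    # bit four bytes into each share hash chain.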
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
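                # the pubkey has no entry in the offsets table 'o': it
                # begins right after the 107-byte share header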
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # this 3.5MB limit used to be hard-coded in Publish, but we removed the
    # limit. Some of the tests below check that the new code correctly
    # allows files larger than the old limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.client.get_storage_broker()
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.client.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))
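            # the file-size limit has been removed, so a modifier that
            # returns more than OLD_MAX_SEGMENT_SIZE bytes is now expected
            # to succeed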
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
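            # CONTENTS is 21 bytes, and k=3, so each encoded share carries
            # 21/3 = 7 bytes of crypttext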
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
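                # the signature covers the fixed-size share prefix: version
                # byte, seqnum, root hash, salt/IV, and the encoding
                # parameters (k, N, segsize, datalen)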
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10))
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should no longer contain the marked shares, and new
            # shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
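        # wipe out peers until only two still hold a share. Each peer holds
        # at most one share here, since ten shares were spread across twenty
        # peers.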
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
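        # corrupt every share at the given offset, either before
        # (corrupt_early) or after the servermap update, then attempt a
        # download. should_succeed=True means the download is expected to
        # survive the corruption.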
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnlessIn(substring, "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

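    # the numeric offsets used by the tests below index into the fixed-size
    # share header (">BQ32s16s BBQQ"): version byte at 0, seqnum at 1,
    # root hash at 9, salt/IV at 41, k at 57, N at 58, segsize at 59, and
    # datalen at 67.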
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an UnknownVersionError in
        # unpack_share().
        d = self._test_corrupt_all(0, "UnknownVersionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature will simply fail to verify
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


1209 class CheckerMixin:
1210     def check_good(self, r, where):
1211         self.failUnless(r.is_healthy(), where)
1212         return r
1213
1214     def check_bad(self, r, where):
1215         self.failIf(r.is_healthy(), where)
1216         return r
1217
1218     def check_expected_failure(self, r, expected_exception, substring, where):
1219         for (peerid, storage_index, shnum, f) in r.problems:
1220             if f.check(expected_exception):
1221                 self.failUnless(substring in str(f),
1222                                 "%s: substring '%s' not in '%s'" %
1223                                 (where, substring, str(f)))
1224                 return
1225         self.fail("%s: didn't see expected exception %s in problems %s" %
1226                   (where, expected_exception, r.problems))
1227
1228
1229 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1230     def setUp(self):
1231         return self.publish_one()
1232
1233
1234     def test_check_good(self):
1235         d = self._fn.check(Monitor())
1236         d.addCallback(self.check_good, "test_check_good")
1237         return d
1238
1239     def test_check_no_shares(self):
1240         for shares in self._storage._peers.values():
1241             shares.clear()
1242         d = self._fn.check(Monitor())
1243         d.addCallback(self.check_bad, "test_check_no_shares")
1244         return d
1245
1246     def test_check_not_enough_shares(self):
1247         for shares in self._storage._peers.values():
1248             for shnum in shares.keys():
1249                 if shnum > 0:
1250                     del shares[shnum]
1251         d = self._fn.check(Monitor())
1252         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1253         return d
1254
1255     def test_check_all_bad_sig(self):
1256         corrupt(None, self._storage, 1) # bad sig
1257         d = self._fn.check(Monitor())
1258         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1259         return d
1260
1261     def test_check_all_bad_blocks(self):
1262         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1263         # the Checker won't notice this; it doesn't look at actual data
1264         d = self._fn.check(Monitor())
1265         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1266         return d
1267
1268     def test_verify_good(self):
1269         d = self._fn.check(Monitor(), verify=True)
1270         d.addCallback(self.check_good, "test_verify_good")
1271         return d
1272
1273     def test_verify_all_bad_sig(self):
1274         corrupt(None, self._storage, 1) # bad sig
1275         d = self._fn.check(Monitor(), verify=True)
1276         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1277         return d
1278
1279     def test_verify_one_bad_sig(self):
1280         corrupt(None, self._storage, 1, [9]) # bad sig
1281         d = self._fn.check(Monitor(), verify=True)
1282         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1283         return d
1284
1285     def test_verify_one_bad_block(self):
1286         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1287         # the Verifier *will* notice this, since it examines every byte
1288         d = self._fn.check(Monitor(), verify=True)
1289         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1290         d.addCallback(self.check_expected_failure,
1291                       CorruptShareError, "block hash tree failure",
1292                       "test_verify_one_bad_block")
1293         return d
1294
1295     def test_verify_one_bad_sharehash(self):
1296         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1297         d = self._fn.check(Monitor(), verify=True)
1298         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1299         d.addCallback(self.check_expected_failure,
1300                       CorruptShareError, "corrupt hashes",
1301                       "test_verify_one_bad_sharehash")
1302         return d
1303
1304     def test_verify_one_bad_encprivkey(self):
1305         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1306         d = self._fn.check(Monitor(), verify=True)
1307         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1308         d.addCallback(self.check_expected_failure,
1309                       CorruptShareError, "invalid privkey",
1310                       "test_verify_one_bad_encprivkey")
1311         return d
1312
1313     def test_verify_one_bad_encprivkey_uncheckable(self):
1314         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1315         readonly_fn = self._fn.get_readonly()
1316         # a read-only node has no way to validate the privkey
1317         d = readonly_fn.check(Monitor(), verify=True)
1318         d.addCallback(self.check_good,
1319                       "test_verify_one_bad_encprivkey_uncheckable")
1320         return d
1321
1322 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1323
1324     def get_shares(self, s):
1325         all_shares = {} # maps (peerid, shnum) to share data
1326         for peerid in s._peers:
1327             shares = s._peers[peerid]
1328             for shnum in shares:
1329                 data = shares[shnum]
1330                 all_shares[ (peerid, shnum) ] = data
1331         return all_shares
1332
1333     def copy_shares(self, ignored=None):
1334         self.old_shares.append(self.get_shares(self._storage))
1335
1336     def test_repair_nop(self):
1337         self.old_shares = []
1338         d = self.publish_one()
1339         d.addCallback(self.copy_shares)
1340         d.addCallback(lambda res: self._fn.check(Monitor()))
1341         d.addCallback(lambda check_results: self._fn.repair(check_results))
1342         def _check_results(rres):
1343             self.failUnless(IRepairResults.providedBy(rres))
1344             # TODO: examine results
1345
1346             self.copy_shares()
1347
1348             initial_shares = self.old_shares[0]
1349             new_shares = self.old_shares[1]
1350             # TODO: this really shouldn't change anything. When we implement
1351             # a "minimal-bandwidth" repairer, change this test to assert:
1352             #self.failUnlessEqual(new_shares, initial_shares)
1353
1354             # all shares should be in the same place as before
1355             self.failUnlessEqual(set(initial_shares.keys()),
1356                                  set(new_shares.keys()))
1357             # but they should all be at a newer seqnum. The IV will be
1358             # different, so the roothash will be too.
1359             for key in initial_shares:
1360                 (version0,
1361                  seqnum0,
1362                  root_hash0,
1363                  IV0,
1364                  k0, N0, segsize0, datalen0,
1365                  o0) = unpack_header(initial_shares[key])
1366                 (version1,
1367                  seqnum1,
1368                  root_hash1,
1369                  IV1,
1370                  k1, N1, segsize1, datalen1,
1371                  o1) = unpack_header(new_shares[key])
1372                 self.failUnlessEqual(version0, version1)
1373                 self.failUnlessEqual(seqnum0+1, seqnum1)
1374                 self.failUnlessEqual(k0, k1)
1375                 self.failUnlessEqual(N0, N1)
1376                 self.failUnlessEqual(segsize0, segsize1)
1377                 self.failUnlessEqual(datalen0, datalen1)
1378         d.addCallback(_check_results)
1379         return d
1380
1381     def failIfSharesChanged(self, ignored=None):
1382         old_shares = self.old_shares[-2]
1383         current_shares = self.old_shares[-1]
1384         self.failUnlessEqual(old_shares, current_shares)
1385
1386     def test_merge(self):
1387         self.old_shares = []
1388         d = self.publish_multiple()
1389         # repair will refuse to merge multiple highest seqnums unless you
1390         # pass force=True
1391         d.addCallback(lambda res:
1392                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1393                                           1:4,3:4,5:4,7:4,9:4}))
1394         d.addCallback(self.copy_shares)
1395         d.addCallback(lambda res: self._fn.check(Monitor()))
1396         def _try_repair(check_results):
1397             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1398             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1399                                  self._fn.repair, check_results)
1400             d2.addCallback(self.copy_shares)
1401             d2.addCallback(self.failIfSharesChanged)
1402             d2.addCallback(lambda res: check_results)
1403             return d2
1404         d.addCallback(_try_repair)
1405         d.addCallback(lambda check_results:
1406                       self._fn.repair(check_results, force=True))
1407         # this should give us 10 shares of the highest roothash
1408         def _check_repair_results(rres):
1409             pass # TODO
1410         d.addCallback(_check_repair_results)
1411         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1412         def _check_smap(smap):
1413             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1414             self.failIf(smap.unrecoverable_versions())
1415             # now, which should have won?
1416             roothash_s4a = self.get_roothash_for(3)
1417             roothash_s4b = self.get_roothash_for(4)
1418             if roothash_s4b > roothash_s4a:
1419                 expected_contents = self.CONTENTS[4]
1420             else:
1421                 expected_contents = self.CONTENTS[3]
1422             new_versionid = smap.best_recoverable_version()
1423             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1424             d2 = self._fn.download_version(smap, new_versionid)
1425             d2.addCallback(self.failUnlessEqual, expected_contents)
1426             return d2
1427         d.addCallback(_check_smap)
1428         return d
1429
1430     def test_non_merge(self):
1431         self.old_shares = []
1432         d = self.publish_multiple()
1433         # repair should not refuse a repair that doesn't need to merge. In
1434         # this case, we combine v2 with v3. The repair should ignore v2 and
1435         # copy v3 into a new v5.
1436         d.addCallback(lambda res:
1437                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1438                                           1:3,3:3,5:3,7:3,9:3}))
1439         d.addCallback(lambda res: self._fn.check(Monitor()))
1440         d.addCallback(lambda check_results: self._fn.repair(check_results))
1441         # this should give us 10 shares of v3
1442         def _check_repair_results(rres):
1443             pass # TODO
1444         d.addCallback(_check_repair_results)
1445         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1446         def _check_smap(smap):
1447             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1448             self.failIf(smap.unrecoverable_versions())
1449             # v3 is the only recoverable version here, so it must have won
1450             roothash_s4a = self.get_roothash_for(3) # unused: no competing version
1451             expected_contents = self.CONTENTS[3]
1452             new_versionid = smap.best_recoverable_version()
1453             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1454             d2 = self._fn.download_version(smap, new_versionid)
1455             d2.addCallback(self.failUnlessEqual, expected_contents)
1456             return d2
1457         d.addCallback(_check_smap)
1458         return d
1459
1460     def get_roothash_for(self, index):
1461         # return the roothash for the first share we see in the saved set
1462         shares = self._copied_shares[index]
1463         for peerid in shares:
1464             for shnum in shares[peerid]:
1465                 share = shares[peerid][shnum]
1466                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1467                           unpack_header(share)
1468                 return root_hash
1469
1470     def test_check_and_repair_readcap(self):
1471         # we can't currently repair from a mutable readcap: #625
1472         self.old_shares = []
1473         d = self.publish_one()
1474         d.addCallback(self.copy_shares)
1475         def _get_readcap(res):
1476             self._fn3 = self._fn.get_readonly()
1477             # also delete some shares
1478             for peerid,shares in self._storage._peers.items():
1479                 shares.pop(0, None)
1480         d.addCallback(_get_readcap)
1481         d.addCallback(lambda res: self._fn3.check_and_repair(Monitor()))
1482         def _check_results(crr):
1483             self.failUnless(ICheckAndRepairResults.providedBy(crr))
1484             # we should detect the unhealthy file, but skip mutable-readcap
1485             # repairs until #625 is fixed
1486             self.failIf(crr.get_pre_repair_results().is_healthy())
1487             self.failIf(crr.get_repair_attempted())
1488             self.failIf(crr.get_post_repair_results().is_healthy())
1489         d.addCallback(_check_results)
1490         return d
1491
1492 class MultipleEncodings(unittest.TestCase):
1493     def setUp(self):
1494         self.CONTENTS = "New contents go here"
1495         num_peers = 20
1496         self._client = FakeClient(num_peers)
1497         self._storage = self._client._storage
1498         d = self._client.create_mutable_file(self.CONTENTS)
1499         def _created(node):
1500             self._fn = node
1501         d.addCallback(_created)
1502         return d
1503
1504     def _encode(self, k, n, data):
1505         # encode 'data' into a peerid->shares dict.
1506
1507         fn2 = FastMutableFileNode(self._client)
1508         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1509         # and _fingerprint
1510         fn = self._fn
1511         fn2.init_from_uri(fn.get_uri())
1512         # then we copy over other fields that are normally fetched from the
1513         # existing shares
1514         fn2._pubkey = fn._pubkey
1515         fn2._privkey = fn._privkey
1516         fn2._encprivkey = fn._encprivkey
1517         # and set the encoding parameters to something completely different
1518         fn2._required_shares = k
1519         fn2._total_shares = n
1520
1521         s = self._client._storage
1522         s._peers = {} # clear existing storage
1523         p2 = Publish(fn2, None)
1524         d = p2.publish(data)
1525         def _published(res):
1526             shares = s._peers
1527             s._peers = {}
1528             return shares
1529         d.addCallback(_published)
1530         return d
1531
1532     def make_servermap(self, mode=MODE_READ, oldmap=None):
1533         if oldmap is None:
1534             oldmap = ServerMap()
1535         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1536         d = smu.update()
1537         return d
1538
1539     def test_multiple_encodings(self):
1540         # we encode the same file in three ways (3-of-10, 4-of-9, 4-of-7),
1541         # then mix up the shares, to make sure that download survives seeing
1542         # a variety of encodings. This is actually kind of tricky to set up.
1543
1544         contents1 = "Contents for encoding 1 (3-of-10) go here"
1545         contents2 = "Contents for encoding 2 (4-of-9) go here"
1546         contents3 = "Contents for encoding 3 (4-of-7) go here"
1547
1548         # we make a retrieval object that doesn't know what encoding
1549         # parameters to use
1550         fn3 = FastMutableFileNode(self._client)
1551         fn3.init_from_uri(self._fn.get_uri())
1552
1553         # now we upload the file with the first encoding, and grab its shares
1554         d = self._encode(3, 10, contents1)
1555         def _encoded_1(shares):
1556             self._shares1 = shares
1557         d.addCallback(_encoded_1)
1558         d.addCallback(lambda res: self._encode(4, 9, contents2))
1559         def _encoded_2(shares):
1560             self._shares2 = shares
1561         d.addCallback(_encoded_2)
1562         d.addCallback(lambda res: self._encode(4, 7, contents3))
1563         def _encoded_3(shares):
1564             self._shares3 = shares
1565         d.addCallback(_encoded_3)
1566
1567         def _merge(res):
1568             log.msg("merging sharelists")
1569             # we merge the shares from the three sets, leaving each shnum in
1570             # its original location, but using a share from set1, set2, or
1571             # set3 according to the following sequence:
1572             #
1573             #  4-of-9  a  s2
1574             #  4-of-9  b  s2
1575             #  4-of-7  c   s3
1576             #  4-of-9  d  s2
1577             #  3-of-10 e s1
1578             #  3-of-10 f s1
1579             #  3-of-10 g s1
1580             #  4-of-9  h  s2
1581             #
1582             # so that neither form can be recovered until fetch [g], at which
1583             # point version-s1 (the 3-of-10 form) becomes recoverable. If
1584             # the implementation latches on to the first version it sees
1585             # (s2), that version does not become recoverable until fetch [h].
1586
1587             # Later, when we implement code that handles multiple versions,
1588             # we can use this framework to assert that all recoverable
1589             # versions are retrieved, and test that 'epsilon' does its job
1590
1591             places = [2, 2, 3, 2, 1, 1, 1, 2]
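            # places maps shnum 0..7 (rows a..h above) to the share set whose
            # copy of that shnum gets placed: 1=s1 (3-of-10), 2=s2 (4-of-9),
            # 3=s3 (4-of-7). Higher shnums fall through to "x" (not placed).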
1592
1593             sharemap = {}
1594             sb = self._client.get_storage_broker()
1595
1596             for peerid in sorted(sb.get_all_serverids()):
1597                 peerid_s = shortnodeid_b2a(peerid)
                 # clear this peer's shares once, before the shnum loop, so
                 # that later iterations don't wipe earlier placements
1598                 self._client._storage._peers[peerid] = peers = {}
1599                 for shnum in self._shares1.get(peerid, {}):
1600                     if shnum < len(places):
1601                         which = places[shnum]
1602                     else:
1603                         which = "x"
1604                     in_1 = shnum in self._shares1[peerid]
1605                     in_2 = shnum in self._shares2.get(peerid, {})
1606                     in_3 = shnum in self._shares3.get(peerid, {})
1607                     #print peerid_s, shnum, which, in_1, in_2, in_3
1608                     if which == 1:
1609                         if in_1:
1610                             peers[shnum] = self._shares1[peerid][shnum]
1611                             sharemap[shnum] = peerid
1612                     elif which == 2:
1613                         if in_2:
1614                             peers[shnum] = self._shares2[peerid][shnum]
1615                             sharemap[shnum] = peerid
1616                     elif which == 3:
1617                         if in_3:
1618                             peers[shnum] = self._shares3[peerid][shnum]
1619                             sharemap[shnum] = peerid
1620
1621             # we don't bother placing any other shares
1622             # now sort the sequence so that share 0 is returned first
1623             new_sequence = [sharemap[shnum]
1624                             for shnum in sorted(sharemap.keys())]
1625             self._client._storage._sequence = new_sequence
1626             log.msg("merge done")
1627         d.addCallback(_merge)
1628         d.addCallback(lambda res: fn3.download_best_version())
1629         def _retrieved(new_contents):
1630             # the current specified behavior is "first version recoverable"
1631             self.failUnlessEqual(new_contents, contents1)
1632         d.addCallback(_retrieved)
1633         return d
1634
1635
1636 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1637
1638     def setUp(self):
1639         return self.publish_multiple()
1640
1641     def test_multiple_versions(self):
1642         # if we see a mix of versions in the grid, download_best_version
1643         # should get the latest one
1644         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1645         d = self._fn.download_best_version()
1646         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1647         # and the checker should report problems
1648         d.addCallback(lambda res: self._fn.check(Monitor()))
1649         d.addCallback(self.check_bad, "test_multiple_versions")
1650
1651         # but if everything is at version 2, that's what we should download
1652         d.addCallback(lambda res:
1653                       self._set_versions(dict([(i,2) for i in range(10)])))
1654         d.addCallback(lambda res: self._fn.download_best_version())
1655         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1656         # if exactly one share is at version 3, we should still get v2
1657         d.addCallback(lambda res:
1658                       self._set_versions({0:3}))
1659         d.addCallback(lambda res: self._fn.download_best_version())
1660         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1661         # but the servermap should see the unrecoverable version. This
1662         # depends upon the single newer share being queried early.
1663         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1664         def _check_smap(smap):
1665             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1666             newer = smap.unrecoverable_newer_versions()
1667             self.failUnlessEqual(len(newer), 1)
1668             verinfo, health = newer.items()[0]
1669             self.failUnlessEqual(verinfo[0], 4)
1670             self.failUnlessEqual(health, (1,3))
1671             self.failIf(smap.needs_merge())
1672         d.addCallback(_check_smap)
1673         # if we have a mix of two parallel versions (s4a and s4b), we could
1674         # recover either
1675         d.addCallback(lambda res:
1676                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1677                                           1:4,3:4,5:4,7:4,9:4}))
1678         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1679         def _check_smap_mixed(smap):
1680             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1681             newer = smap.unrecoverable_newer_versions()
1682             self.failUnlessEqual(len(newer), 0)
1683             self.failUnless(smap.needs_merge())
1684         d.addCallback(_check_smap_mixed)
1685         d.addCallback(lambda res: self._fn.download_best_version())
1686         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1687                                                   res == self.CONTENTS[4]))
1688         return d
1689
1690     def test_replace(self):
1691         # if we see a mix of versions in the grid, we should be able to
1692         # replace them all with a newer version
1693
1694         # if exactly one share is at version 3, we should download (and
1695         # replace) v2, and the result should be v4. Note that the index we
1696         # give to _set_versions is different than the sequence number.
1697         target = dict([(i,2) for i in range(10)]) # seqnum3
1698         target[0] = 3 # seqnum4
1699         self._set_versions(target)
1700
1701         def _modify(oldversion, servermap, first_time):
1702             return oldversion + " modified"
1703         d = self._fn.modify(_modify)
1704         d.addCallback(lambda res: self._fn.download_best_version())
1705         expected = self.CONTENTS[2] + " modified"
1706         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1707         # and the servermap should indicate that the outlier was replaced too
1708         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1709         def _check_smap(smap):
1710             self.failUnlessEqual(smap.highest_seqnum(), 5)
1711             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1712             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1713         d.addCallback(_check_smap)
1714         return d
1715
1716
1717 class Utils(unittest.TestCase):
1718     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1719         # we compare this against sets of integers
1720         x = set(range(x_start, x_start+x_length))
1721         y = set(range(y_start, y_start+y_length))
1722         should_be_inside = x.issubset(y)
1723         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1724                                                          y_start, y_length),
1725                              str((x_start, x_length, y_start, y_length)))
1726
1727     def test_cache_inside(self):
1728         c = ResponseCache()
1729         x_start = 10
1730         x_length = 5
1731         for y_start in range(8, 17):
1732             for y_length in range(8):
1733                 self._do_inside(c, x_start, x_length, y_start, y_length)
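    # (illustrative sketch, not part of the original tests) assuming
    # _inside(xs, xl, ys, yl) tests whether [xs, xs+xl) lies wholly within
    # [ys, ys+yl), as the set-based oracle above does:
    #   c._inside(10, 5, 8, 8)  -> True   ([10,15) is inside [8,16))
    #   c._inside(10, 5, 12, 8) -> False  (10 and 11 fall outside [12,20))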
1734
1735     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1736         # we compare this against sets of integers
1737         x = set(range(x_start, x_start+x_length))
1738         y = set(range(y_start, y_start+y_length))
1739         overlap = bool(x.intersection(y))
1740         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1741                                                       y_start, y_length),
1742                              str((x_start, x_length, y_start, y_length)))
1743
1744     def test_cache_overlap(self):
1745         c = ResponseCache()
1746         x_start = 10
1747         x_length = 5
1748         for y_start in range(8, 17):
1749             for y_length in range(8):
1750                 self._do_overlap(c, x_start, x_length, y_start, y_length)
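    # (illustrative sketch) likewise, _does_overlap should agree with a
    # non-empty set intersection of the two half-open spans:
    #   c._does_overlap(10, 5, 14, 3) -> True   ([10,15) and [14,17) share 14)
    #   c._does_overlap(10, 5, 15, 3) -> False  (the spans merely touch)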
1751
1752     def test_cache(self):
1753         c = ResponseCache()
1754         # xdata = base62.b2a(os.urandom(100))[:100]
1755         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1756         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1757         nope = (None, None)
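        # semantics exercised below: add(verinfo, shnum, offset, data, when)
        # caches one span, and read(verinfo, shnum, offset, length) returns
        # (data, when) only when the request lies entirely within a single
        # cached span, otherwise (None, None)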
1758         c.add("v1", 1, 0, xdata, "time0")
1759         c.add("v1", 1, 2000, ydata, "time1")
1760         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1761         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1762         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1763         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1764         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1765         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1766         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1767         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1768         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1769         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1770         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1771         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1772         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1773         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1774         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1775         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1776         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1777         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1778
1779         # optional: join fragments
1780         c = ResponseCache()
1781         c.add("v1", 1, 0, xdata[:10], "time0")
1782         c.add("v1", 1, 10, xdata[10:20], "time1")
1783         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
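        # until such joining is implemented, a read that straddles the two
        # cached fragments is expected to miss:
        #   c.read("v1", 1, 0, 20)  -> (None, None)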
1784
1785 class Exceptions(unittest.TestCase):
1786     def test_repr(self):
1787         nmde = NeedMoreDataError(100, 50, 100)
1788         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1789         ucwe = UncoordinatedWriteError()
1790         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1791
1792 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1793 # instances which always succeed. So we need a less-fake one.
1794
1795 class IntentionalError(Exception):
1796     pass
1797
1798 class LocalWrapper:
1799     def __init__(self, original):
1800         self.original = original
1801         self.broken = False
1802         self.post_call_notifier = None
1803     def callRemote(self, methname, *args, **kwargs):
1804         def _call():
1805             if self.broken:
1806                 raise IntentionalError("I was asked to break")
1807             meth = getattr(self.original, "remote_" + methname)
1808             return meth(*args, **kwargs)
1809         d = fireEventually()
1810         d.addCallback(lambda res: _call())
1811         if self.post_call_notifier:
1812             d.addCallback(self.post_call_notifier, methname)
1813         return d
1814
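# (illustrative sketch, hypothetical usage) once wrapped, a server can be
# broken at will, and every subsequent callRemote will errback with
# IntentionalError:
#
#   lw = LocalWrapper(ss)
#   lw.broken = True
#   d = lw.callRemote("get_version")  # any remote_* method would do
#   d.addErrback(lambda f: f.trap(IntentionalError))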
1815 class LessFakeClient(FakeClient):
1816
1817     def __init__(self, basedir, num_peers=10):
1818         self._num_peers = num_peers
1819         peerids = [tagged_hash("peerid", "%d" % i)[:20]
1820                    for i in range(self._num_peers)]
1821         self.storage_broker = StorageFarmBroker(None, True)
1822         for peerid in peerids:
1823             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1824             make_dirs(peerdir)
1825             ss = StorageServer(peerdir, peerid)
1826             lw = LocalWrapper(ss)
1827             self.storage_broker.test_add_server(peerid, lw)
1828         self.nodeid = "fakenodeid"
1829
1830
1831 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1832     def test_publish_surprise(self):
1833         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1834         self.client = LessFakeClient(basedir)
1835         d = self.client.create_mutable_file("contents 1")
1836         def _created(n):
1837             d = defer.succeed(None)
1838             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1839             def _got_smap1(smap):
1840                 # stash the old state of the file
1841                 self.old_map = smap
1842             d.addCallback(_got_smap1)
1843             # then modify the file, leaving the old map untouched
1844             d.addCallback(lambda res: log.msg("starting winning write"))
1845             d.addCallback(lambda res: n.overwrite("contents 2"))
1846             # now attempt to modify the file with the old servermap. This
1847             # will look just like an uncoordinated write, in which every
1848             # single share got updated between our mapupdate and our publish
1849             d.addCallback(lambda res: log.msg("starting doomed write"))
1850             d.addCallback(lambda res:
1851                           self.shouldFail(UncoordinatedWriteError,
1852                                           "test_publish_surprise", None,
1853                                           n.upload,
1854                                           "contents 2a", self.old_map))
1855             return d
1856         d.addCallback(_created)
1857         return d
1858
1859     def test_retrieve_surprise(self):
1860         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1861         self.client = LessFakeClient(basedir)
1862         d = self.client.create_mutable_file("contents 1")
1863         def _created(n):
1864             d = defer.succeed(None)
1865             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1866             def _got_smap1(smap):
1867                 # stash the old state of the file
1868                 self.old_map = smap
1869             d.addCallback(_got_smap1)
1870             # then modify the file, leaving the old map untouched
1871             d.addCallback(lambda res: log.msg("starting winning write"))
1872             d.addCallback(lambda res: n.overwrite("contents 2"))
1873             # now attempt to retrieve the old version with the old servermap.
1874             # This will look like someone has changed the file since we
1875             # updated the servermap.
1876             d.addCallback(lambda res: n._cache._clear())
1877             d.addCallback(lambda res: log.msg("starting doomed read"))
1878             d.addCallback(lambda res:
1879                           self.shouldFail(NotEnoughSharesError,
1880                                           "test_retrieve_surprise",
1881                                           "ran out of peers: have 0 shares (k=3)",
1882                                           n.download_version,
1883                                           self.old_map,
1884                                           self.old_map.best_recoverable_version(),
1885                                           ))
1886             return d
1887         d.addCallback(_created)
1888         return d
1889
1890     def test_unexpected_shares(self):
1891         # upload the file, take a servermap, shut down one of the servers,
1892         # upload it again (causing shares to appear on a new server), then
1893         # upload using the old servermap. The last upload should fail with an
1894         # UncoordinatedWriteError, because of the shares that didn't appear
1895         # in the servermap.
1896         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1897         self.client = LessFakeClient(basedir)
1898         d = self.client.create_mutable_file("contents 1")
1899         def _created(n):
1900             d = defer.succeed(None)
1901             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1902             def _got_smap1(smap):
1903                 # stash the old state of the file
1904                 self.old_map = smap
1905                 # now shut down one of the servers
1906                 peer0 = list(smap.make_sharemap()[0])[0]
1907                 self.client.debug_remove_connection(peer0)
1908                 # then modify the file, leaving the old map untouched
1909                 log.msg("starting winning write")
1910                 return n.overwrite("contents 2")
1911             d.addCallback(_got_smap1)
1912             # now attempt to modify the file with the old servermap. This
1913             # will look just like an uncoordinated write, in which every
1914             # single share got updated between our mapupdate and our publish
1915             d.addCallback(lambda res: log.msg("starting doomed write"))
1916             d.addCallback(lambda res:
1917                           self.shouldFail(UncoordinatedWriteError,
1918                                           "test_unexpected_shares", None,
1919                                           n.upload,
1920                                           "contents 2a", self.old_map))
1921             return d
1922         d.addCallback(_created)
1923         return d
1924
1925     def test_bad_server(self):
1926         # Break one server, then create the file: the initial publish should
1927         # complete with an alternate server. Breaking a second server should
1928         # not prevent an update from succeeding either.
1929         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1930         self.client = LessFakeClient(basedir, 20)
1931         # to make sure that one of the initial peers is broken, we have to
1932         # get creative. We create the keys, so we can figure out the storage
1933         # index, but we hold off on doing the initial publish until we've
1934         # broken the server on which the first share wants to be stored.
1935         n = FastMutableFileNode(self.client)
1936         d = defer.succeed(None)
1937         d.addCallback(n._generate_pubprivkeys, keysize=522)
1938         d.addCallback(n._generated)
1939         def _break_peer0(res):
1940             si = n.get_storage_index()
1941             peerlist = self.client.storage_broker.get_servers_for_index(si)
1942             peerid0, connection0 = peerlist[0]
1943             peerid1, connection1 = peerlist[1]
1944             connection0.broken = True
1945             self.connection1 = connection1
1946         d.addCallback(_break_peer0)
1947         # now let the initial publish finally happen
1948         d.addCallback(lambda res: n._upload("contents 1", None))
1949         # that ought to work
1950         d.addCallback(lambda res: n.download_best_version())
1951         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1952         # now break the second peer
1953         def _break_peer1(res):
1954             self.connection1.broken = True
1955         d.addCallback(_break_peer1)
1956         d.addCallback(lambda res: n.overwrite("contents 2"))
1957         # that ought to work too
1958         d.addCallback(lambda res: n.download_best_version())
1959         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1960         def _explain_error(f):
1961             print f
1962             if f.check(NotEnoughServersError):
1963                 print "first_error:", f.value.first_error
1964             return f
1965         d.addErrback(_explain_error)
1966         return d
1967
1968     def test_bad_server_overlap(self):
1969         # like test_bad_server, but with no extra unused servers to fall back
1970         # upon. This means that we must re-use a server which we've already
1971         # used. If we don't remember the fact that we sent them one share
1972         # already, we'll mistakenly think we're experiencing an
1973         # UncoordinatedWriteError.
1974
1975         # Break one server, then create the file: the initial publish should
1976         # complete with an alternate server. Breaking a second server should
1977         # not prevent an update from succeeding either.
1978         basedir = os.path.join("mutable/CollidingWrites/test_bad_server_overlap")
1979         self.client = LessFakeClient(basedir, 10)
1980         sb = self.client.get_storage_broker()
1981
1982         peerids = list(sb.get_all_serverids())
1983         self.client.debug_break_connection(peerids[0])
1984
1985         d = self.client.create_mutable_file("contents 1")
1986         def _created(n):
1987             d = n.download_best_version()
1988             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1989             # now break one of the remaining servers
1990             def _break_second_server(res):
1991                 self.client.debug_break_connection(peerids[1])
1992             d.addCallback(_break_second_server)
1993             d.addCallback(lambda res: n.overwrite("contents 2"))
1994             # that ought to work too
1995             d.addCallback(lambda res: n.download_best_version())
1996             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1997             return d
1998         d.addCallback(_created)
1999         return d
2000
2001     def test_publish_all_servers_bad(self):
2002         # Break all servers: the publish should fail
2003         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
2004         self.client = LessFakeClient(basedir, 20)
2005         sb = self.client.get_storage_broker()
2006         for peerid in sb.get_all_serverids():
2007             self.client.debug_break_connection(peerid)
2008         d = self.shouldFail(NotEnoughServersError,
2009                             "test_publish_all_servers_bad",
2010                             "Ran out of non-bad servers",
2011                             self.client.create_mutable_file, "contents")
2012         return d
2013
2014     def test_publish_no_servers(self):
2015         # no servers at all: the publish should fail
2016         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
2017         self.client = LessFakeClient(basedir, 0)
2018         d = self.shouldFail(NotEnoughServersError,
2019                             "test_publish_no_servers",
2020                             "Ran out of non-bad servers",
2021                             self.client.create_mutable_file, "contents")
2022         return d
2023     test_publish_no_servers.timeout = 30
2024
2025
2026     def test_privkey_query_error(self):
2027         # when a servermap is updated with MODE_WRITE, it tries to get the
2028         # privkey. Something might go wrong during this query attempt.
2029         self.client = FakeClient(20)
2030         # we need some contents that are large enough to push the privkey out
2031         # of the early part of the file
2032         LARGE = "These are Larger contents" * 200 # about 5KB
2033         d = self.client.create_mutable_file(LARGE)
2034         def _created(n):
2035             self.uri = n.get_uri()
2036             self.n2 = self.client.create_node_from_uri(self.uri)
2037             # we start by doing a map update to figure out which is the first
2038             # server.
2039             return n.get_servermap(MODE_WRITE)
2040         d.addCallback(_created)
2041         d.addCallback(lambda res: fireEventually(res))
2042         def _got_smap1(smap):
2043             peer0 = list(smap.make_sharemap()[0])[0]
2044             # we arrange for this peer's response to arrive first, so that
2045             # it will be asked for the privkey first
2046             self.client._storage._sequence = [peer0]
2047             # now we make the peer fail their second query
2048             self.client._storage._special_answers[peer0] = ["normal", "fail"]
2049         d.addCallback(_got_smap1)
2050         # now we update a servermap from a new node (which doesn't have the
2051         # privkey yet, forcing it to use a separate privkey query). Each
2052         # query response will trigger a privkey query, and since we're using
2053         # _sequence to make the peer0 response come back first, we'll send it
2054         # a privkey query first, and _sequence will again ensure that the
2055         # peer0 query will also come back before the others, and then
2056         # _special_answers will make sure that the query raises an exception.
2057         # The whole point of these hijinks is to exercise the code in
2058         # _privkey_query_failed. Note that the map-update will succeed, since
2059         # we'll just get a copy from one of the other shares.
2060         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2061         # Using FakeStorage._sequence means there will be read requests still
2062         # floating around; wait for them to retire
2063         def _cancel_timer(res):
2064             if self.client._storage._pending_timer:
2065                 self.client._storage._pending_timer.cancel()
2066             return res
2067         d.addBoth(_cancel_timer)
2068         return d
2069
2070     def test_privkey_query_missing(self):
2071         # like test_privkey_query_error, but the shares are deleted by the
2072         # second query, instead of raising an exception.
2073         self.client = FakeClient(20)
2074         LARGE = "These are Larger contents" * 200 # about 5KB
2075         d = self.client.create_mutable_file(LARGE)
2076         def _created(n):
2077             self.uri = n.get_uri()
2078             self.n2 = self.client.create_node_from_uri(self.uri)
2079             return n.get_servermap(MODE_WRITE)
2080         d.addCallback(_created)
2081         d.addCallback(lambda res: fireEventually(res))
2082         def _got_smap1(smap):
2083             peer0 = list(smap.make_sharemap()[0])[0]
2084             self.client._storage._sequence = [peer0]
2085             self.client._storage._special_answers[peer0] = ["normal", "none"]
2086         d.addCallback(_got_smap1)
2087         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2088         def _cancel_timer(res):
2089             if self.client._storage._pending_timer:
2090                 self.client._storage._pending_timer.cancel()
2091             return res
2092         d.addBoth(_cancel_timer)
2093         return d