
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri
from allmydata.storage.server import StorageServer
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.api import eventually, fireEventually
from foolscap.logging import log
from allmydata.storage_client import StorageFarmBroker

from allmydata.mutable.filenode import MutableFileNode, BackoffAgent
from allmydata.mutable.common import ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repairer import MustForceRepairError

import common_util as testutil

class IntentionalError(Exception):
    # raised by FakeStorage's "fail" answer mode (below) to simulate a
    # server that errors out
    pass

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()

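# An illustrative sketch (an editorial note, not part of the original tests)
# of how a test can steer FakeStorage: assigning a list of peerids to
# _sequence makes read() defer every answer and release them in that order,
# one second after the first query arrives, while _special_answers can make
# a single peer misbehave for its next read:
#
#   s = FakeStorage()
#   s._sequence = sorted(s._peers.keys())      # answer reads in this order
#   s._special_answers[some_peerid] = ["fail"] # next read() from this peer
#                                              # delivers Failure(IntentionalError())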

class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: swallow the result and any errors
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)
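    # The TODO above notes that the answer should really be driven by
    # read_vector, not by the test-vector specimens. An illustrative sketch
    # (an editorial assumption, not the original code) of that behavior:
    #
    #   d = self.storage.read(self.peerid, storage_index)
    #   def _read(shares):
    #       return dict((shnum, [data[offset:offset+length]
    #                            for (offset, length) in read_vector])
    #                   for (shnum, data) in shares.items())
    #   d.addCallback(_read)
    #   d.addCallback(lambda readv: (True, readv))
    #   return d
    #
    # The tests here only need the always-pass parrot, so the simpler stub
    # above is kept.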


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        peerids = [tagged_hash("peerid", "%d" % i)[:20]
                   for i in range(self._num_peers)]
        self.nodeid = "fakenodeid"
        self.storage_broker = StorageFarmBroker(None, True)
        for peerid in peerids:
            fss = FakeStorageServer(peerid, self._storage)
            self.storage_broker.test_add_server(peerid, fss)

    def get_storage_broker(self):
        return self.storage_broker
    def debug_break_connection(self, peerid):
        self.storage_broker.test_servers[peerid].broken = True
    def debug_remove_connection(self, peerid):
        self.storage_broker.test_servers.pop(peerid)
    def debug_get_connection(self, peerid):
        return self.storage_broker.test_servers[peerid]

    def get_encoding_parameters(self):
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def get_history(self):
        return None

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
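# For example, flip_bit("\x00abc", 0) == "\x01abc": only the low bit of the
# byte at byte_offset is inverted.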

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
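# Typical invocations, taken from the tests below:
#   corrupt(None, self._storage, 1)                       # flip a bit in the seqnum field
#   corrupt(None, self._storage, "share_data", range(5))  # corrupt shares 0-4 only
# The pass-through 'res' argument lets corrupt() sit directly in a Deferred
# callback chain.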

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    # this used to be in Publish, but we removed the limit. Some of
    # these tests test whether the new code correctly allows files
    # larger than the limit.
    OLD_MAX_SEGMENT_SIZE = 3500000
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            sb = self.client.get_storage_broker()
            peer0 = sorted(sb.get_all_serverids())[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (self.OLD_MAX_SEGMENT_SIZE + 1)
        d = self.client.create_mutable_file(BIG)
        def _created(n):
            d = n.overwrite(BIG)
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum, which):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum, which))
        return d

    def test_modify(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        def _non_modifier(old_contents, servermap, first_time):
            return old_contents
        def _none_modifier(old_contents, servermap, first_time):
            return None
        def _error_modifier(old_contents, servermap, first_time):
            raise ValueError("oops")
        def _toobig_modifier(old_contents, servermap, first_time):
            return "b" * (self.OLD_MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _ucw_error_non_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once, and don't actually
            # modify the contents on subsequent invocations
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "non"))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "none"))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "err"))

            # the "big" checks below are a holdover from when a too-big
            # write was expected to fail at this point; the size limit is
            # gone now, so _toobig_modifier is exercised (and must succeed)
            # at the end of this test instead
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "big"))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "ucw"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)

            # in practice, this n.modify call should publish twice: the first
            # one gets a UCWE, the second does not. But our test jig (in
            # which the modifier raises the UCWE) skips over the first one,
            # so in this test there will be only one publish, and the seqnum
            # will only be one larger than the previous test, not two (i.e. 4
            # instead of 5).
            d.addCallback(lambda res: n.modify(_ucw_error_non_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 4, "ucw"))
            d.addCallback(lambda res: n.modify(_toobig_modifier))
            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents, servermap, first_time):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents, servermap, first_time):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents, servermap, first_time):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "m"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2, "stop"))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "pause"))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3, "giveup"))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d
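    # (CONTENTS is 21 bytes and k=3, so each of the 10 erasure-coded shares
    # carries 21/3 == 7 bytes; that is where the len(sh) == 7 assertion
    # above comes from.)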

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10))
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d
    # TODO: when we publish to 20 peers, we should get one share per peer, on
    # 10 of them. When we publish to 3 peers, we should get either 3 or 4
    # shares per peer. When we publish to zero peers, we should get a
    # NotEnoughSharesError.

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]

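    # An illustrative use of _set_versions (mirroring the rollback dict in
    # publish_multiple above): _set_versions({0: 1, 9: 3}) would pin share 0
    # at the "Contents 1" version and share 9 at "Contents 3a", leaving all
    # other shares untouched.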

class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

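    # A note on the integer offsets used in the tests below (an editorial
    # gloss, derived from the ">BQ32s16s BBQQ" header format exercised in
    # MakeShares.test_generate): the version byte lives at offset 0, seqnum
    # at 1, root hash at 9, salt/IV at 41, k at 57, N at 58, segsize at 59,
    # and datalen at 67. The offset table (32 bytes here) follows the
    # header, which is how corrupt() arrives at 107 for the start of the
    # pubkey.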
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
1188         corrupt(None, self._storage, "share_data", range(5))
1189         d = self.make_servermap()
1190         def _do_retrieve(servermap):
1191             ver = servermap.best_recoverable_version()
1192             self.failUnless(ver)
1193             return self._fn.download_best_version()
1194         d.addCallback(_do_retrieve)
1195         d.addCallback(lambda new_contents:
1196                       self.failUnlessEqual(new_contents, self.CONTENTS))
1197         return d
1198
1199     def test_download_fails(self):
1200         corrupt(None, self._storage, "signature")
1201         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1202                             "no recoverable versions",
1203                             self._fn.download_best_version)
1204         return d
1205
1206
1207 class CheckerMixin:
1208     def check_good(self, r, where):
1209         self.failUnless(r.is_healthy(), where)
1210         return r
1211
1212     def check_bad(self, r, where):
1213         self.failIf(r.is_healthy(), where)
1214         return r
1215
1216     def check_expected_failure(self, r, expected_exception, substring, where):
1217         for (peerid, storage_index, shnum, f) in r.problems:
1218             if f.check(expected_exception):
1219                 self.failUnless(substring in str(f),
1220                                 "%s: substring '%s' not in '%s'" %
1221                                 (where, substring, str(f)))
1222                 return
1223         self.fail("%s: didn't see expected exception %s in problems %s" %
1224                   (where, expected_exception, r.problems))
1225
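# A hedged usage note for the mixin above, mirroring how it is chained later
# in this file:
#
#   d = self._fn.check(Monitor(), verify=True)
#   d.addCallback(self.check_bad, "where")
#   d.addCallback(self.check_expected_failure,
#                 CorruptShareError, "expected substring", "where")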

class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()


    def test_check_good(self):
        d = self._fn.check(Monitor())
        d.addCallback(self.check_good, "test_check_good")
        return d

    def test_check_no_shares(self):
        for shares in self._storage._peers.values():
            shares.clear()
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_no_shares")
        return d

    def test_check_not_enough_shares(self):
        for shares in self._storage._peers.values():
            for shnum in shares.keys():
                if shnum > 0:
                    del shares[shnum]
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_not_enough_shares")
        return d

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_all_bad_sig")
        return d

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this, since it doesn't look at actual data
        d = self._fn.check(Monitor())
        d.addCallback(self.check_good, "test_check_all_bad_blocks")
        return d

    def test_verify_good(self):
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_good, "test_verify_good")
        return d

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_all_bad_sig")
        return d

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sig")
        return d

    def test_verify_one_bad_block(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_block")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "block hash tree failure",
                      "test_verify_one_bad_block")
        return d

    def test_verify_one_bad_sharehash(self):
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "corrupt hashes",
                      "test_verify_one_bad_sharehash")
        return d

    def test_verify_one_bad_encprivkey(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        d = self._fn.check(Monitor(), verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "invalid privkey",
                      "test_verify_one_bad_encprivkey")
        return d

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        readonly_fn = self._fn.get_readonly()
        # a read-only node has no way to validate the privkey
        d = readonly_fn.check(Monitor(), verify=True)
        d.addCallback(self.check_good,
                      "test_verify_one_bad_encprivkey_uncheckable")
        return d

class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):

    def get_shares(self, s):
        all_shares = {} # maps (peerid, shnum) to share data
        for peerid in s._peers:
            shares = s._peers[peerid]
            for shnum in shares:
                data = shares[shnum]
                all_shares[ (peerid, shnum) ] = data
        return all_shares

    def copy_shares(self, ignored=None):
        self.old_shares.append(self.get_shares(self._storage))

    def test_repair_nop(self):
        self.old_shares = []
        d = self.publish_one()
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        def _check_results(rres):
            self.failUnless(IRepairResults.providedBy(rres))
            # TODO: examine results

            self.copy_shares()

            initial_shares = self.old_shares[0]
            new_shares = self.old_shares[1]
            # TODO: this really shouldn't change anything. When we implement
            # a "minimal-bandwidth" repairer, change this test to assert:
            #self.failUnlessEqual(new_shares, initial_shares)

            # all shares should be in the same place as before
            self.failUnlessEqual(set(initial_shares.keys()),
                                 set(new_shares.keys()))
            # but they should all be at a newer seqnum. The IV will be
            # different, so the roothash will be too.
            for key in initial_shares:
                (version0,
                 seqnum0,
                 root_hash0,
                 IV0,
                 k0, N0, segsize0, datalen0,
                 o0) = unpack_header(initial_shares[key])
                (version1,
                 seqnum1,
                 root_hash1,
                 IV1,
                 k1, N1, segsize1, datalen1,
                 o1) = unpack_header(new_shares[key])
                self.failUnlessEqual(version0, version1)
                self.failUnlessEqual(seqnum0+1, seqnum1)
                self.failUnlessEqual(k0, k1)
                self.failUnlessEqual(N0, N1)
                self.failUnlessEqual(segsize0, segsize1)
                self.failUnlessEqual(datalen0, datalen1)
        d.addCallback(_check_results)
        return d

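    # A hedged helper sketch (not called by the tests themselves): given two
    # 9-tuples as returned by unpack_header(), list the fields that differ.
    # Handy when eyeballing why a repair bumped more than just the seqnum.
    def summarize_header_change(self, old_header, new_header):
        names = ["version", "seqnum", "root_hash", "IV",
                 "k", "N", "segsize", "datalen", "offsets"]
        return [(name, old, new)
                for (name, old, new) in zip(names, old_header, new_header)
                if old != new]
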
    def failIfSharesChanged(self, ignored=None):
        old_shares = self.old_shares[-2]
        current_shares = self.old_shares[-1]
        self.failUnlessEqual(old_shares, current_shares)

    def test_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair will refuse to merge multiple highest seqnums unless you
        # pass force=True
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(self.copy_shares)
        d.addCallback(lambda res: self._fn.check(Monitor()))
        def _try_repair(check_results):
            ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
            d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
                                 self._fn.repair, check_results)
            d2.addCallback(self.copy_shares)
            d2.addCallback(self.failIfSharesChanged)
            d2.addCallback(lambda res: check_results)
            return d2
        d.addCallback(_try_repair)
        d.addCallback(lambda check_results:
                      self._fn.repair(check_results, force=True))
        # this should give us 10 shares of the highest roothash
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # now, which should have won?
            roothash_s4a = self.get_roothash_for(3)
            roothash_s4b = self.get_roothash_for(4)
            if roothash_s4b > roothash_s4a:
                expected_contents = self.CONTENTS[4]
            else:
                expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def test_non_merge(self):
        self.old_shares = []
        d = self.publish_multiple()
        # repair should not refuse a repair that doesn't need to merge. In
        # this case, we combine v2 with v3. The repair should ignore v2 and
        # copy v3 into a new v5.
        d.addCallback(lambda res:
                      self._set_versions({0:2,2:2,4:2,6:2,8:2,
                                          1:3,3:3,5:3,7:3,9:3}))
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(lambda check_results: self._fn.repair(check_results))
        # this should give us 10 shares of v3
        def _check_repair_results(rres):
            pass # TODO
        d.addCallback(_check_repair_results)
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
            self.failIf(smap.unrecoverable_versions())
            # only v3 was recoverable, so that's the version the repair
            # should have copied forward
            expected_contents = self.CONTENTS[3]
            new_versionid = smap.best_recoverable_version()
            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
            d2 = self._fn.download_version(smap, new_versionid)
            d2.addCallback(self.failUnlessEqual, expected_contents)
            return d2
        d.addCallback(_check_smap)
        return d

    def get_roothash_for(self, index):
        # return the roothash for the first share we see in the saved set
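        # (any share will do: every share of a given version carries the
        # same roothash in its header)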
        shares = self._copied_shares[index]
        for peerid in shares:
            for shnum in shares[peerid]:
                share = shares[peerid][shnum]
                (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
                          unpack_header(share)
                return root_hash

class MultipleEncodings(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we publish contents1 under the first encoding, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2, or
            # set3 according to the following sequence:
            #
            #  4-of-9  a  s2
            #  4-of-9  b  s2
            #  4-of-7  c   s3
            #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
            #  4-of-9  h  s2
            #
            # so that neither form can be recovered until fetch [f], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will be recoverable at fetch [g].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
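            # (hedged gloss: places[shnum] names the share-set, s1/s2/s3,
            # that supplies that shnum; shnums beyond the end of the list
            # fall through to "x" below and are left unplaced)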

            sharemap = {}
            sb = self._client.get_storage_broker()

            for peerid in sorted(sb.get_all_serverids()):
                peerid_s = shortnodeid_b2a(peerid)
                # create each peer's bucket exactly once, not once per shnum,
                # so a second share placed on the same peer cannot wipe out
                # the first
                self._client._storage._peers[peerid] = peers = {}
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    in_1 = shnum in self._shares1[peerid]
                    in_2 = shnum in self._shares2.get(peerid, {})
                    in_3 = shnum in self._shares3.get(peerid, {})
                    #print peerid_s, shnum, which, in_1, in_2, in_3
                    if which == 1:
                        if in_1:
                            peers[shnum] = self._shares1[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 2:
                        if in_2:
                            peers[shnum] = self._shares2[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 3:
                        if in_3:
                            peers[shnum] = self._shares3[peerid][shnum]
                            sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._client._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d


class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):

    def setUp(self):
        return self.publish_multiple()

    def test_multiple_versions(self):
        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
        d = self._fn.download_best_version()
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
        # and the checker should report problems
        d.addCallback(lambda res: self._fn.check(Monitor()))
        d.addCallback(self.check_bad, "test_multiple_versions")

        # but if everything is at version 2, that's what we should download
        d.addCallback(lambda res:
                      self._set_versions(dict([(i,2) for i in range(10)])))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # if exactly one share is at version 3, we should still get v2
        d.addCallback(lambda res:
                      self._set_versions({0:3}))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = newer.items()[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
                                                  res == self.CONTENTS[4]))
        return d

    def test_replace(self):
        # if we see a mix of versions in the grid, we should be able to
        # replace them all with a newer version

        # if exactly one share is at version 3, we should download (and
        # replace) v2, and the result should be v4. Note that the index we
        # give to _set_versions is different than the sequence number.
        target = dict([(i,2) for i in range(10)]) # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        def _modify(oldversion, servermap, first_time):
            return oldversion + " modified"
        d = self._fn.modify(_modify)
        d.addCallback(lambda res: self._fn.download_best_version())
        expected = self.CONTENTS[2] + " modified"
        d.addCallback(lambda res: self.failUnlessEqual(res, expected))
        # and the servermap should indicate that the outlier was replaced too
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
        d.addCallback(_check_smap)
        return d


class Utils(unittest.TestCase):
    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        should_be_inside = x.issubset(y)
        self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
                                                         y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

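    # A hedged reference sketch (not part of the original assertions): the
    # closed-form test that the set-based oracle in _do_inside encodes,
    # assuming half-open [start, start+length) intervals and a non-empty x
    # (an empty x is vacuously inside everything under the set oracle).
    @staticmethod
    def _inside_reference(x_start, x_length, y_start, y_length):
        return (x_start >= y_start and
                x_start + x_length <= y_start + y_length)
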
    def test_cache_inside(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(c, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        overlap = bool(x.intersection(y))
        self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
                                                      y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

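    # A matching hedged sketch for _do_overlap: two non-empty half-open
    # intervals overlap iff each starts before the other ends. (The set
    # oracle additionally treats any empty interval as overlapping nothing.)
    @staticmethod
    def _overlap_reference(x_start, x_length, y_start, y_length):
        return (x_start < y_start + y_length and
                y_start < x_start + x_length)
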
    def test_cache_overlap(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(c, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
        self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
        self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
        self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)

        # optional: join fragments
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))

class Exceptions(unittest.TestCase):
    def test_repr(self):
        nmde = NeedMoreDataError(100, 50, 100)
        self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
        ucwe = UncoordinatedWriteError()
        self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))

# we can't do this test with a FakeClient, since it uses FakeStorageServer
# instances which always succeed. So we need a less-fake one.

class IntentionalError(Exception):
    pass

class LocalWrapper:
    def __init__(self, original):
        self.original = original
        self.broken = False
        self.post_call_notifier = None
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)
        d = fireEventually()
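        # fireEventually() fires its Deferred on a later eventual-send turn,
        # so even this in-process call stays asynchronous, like a real
        # network round-trip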
        d.addCallback(lambda res: _call())
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, methname)
        return d

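# A hedged usage sketch (not exercised by the tests): LocalWrapper stands in
# for a foolscap RemoteReference, so a test can break one server mid-run and
# watch the failure path. remote_get_version() is assumed here just because
# StorageServer defines it; any remote_* method would do.
def _demo_breakable_wrapper(storage_server):
    lw = LocalWrapper(storage_server)
    d = lw.callRemote("get_version")
    def _break_and_retry(ignored):
        lw.broken = True
        # this second call will errback with IntentionalError
        return lw.callRemote("get_version")
    d.addCallback(_break_and_retry)
    return d
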
class LessFakeClient(FakeClient):

    def __init__(self, basedir, num_peers=10):
        self._num_peers = num_peers
        peerids = [tagged_hash("peerid", "%d" % i)[:20]
                   for i in range(self._num_peers)]
        self.storage_broker = StorageFarmBroker(None, True)
        for peerid in peerids:
            peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
            make_dirs(peerdir)
            ss = StorageServer(peerdir, peerid)
            lw = LocalWrapper(ss)
            self.storage_broker.test_add_server(peerid, lw)
        self.nodeid = "fakenodeid"


class Problems(unittest.TestCase, testutil.ShouldFailMixin):
    def test_publish_surprise(self):
        basedir = os.path.join("mutable/CollidingWrites/test_surprise")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
            d.addCallback(_got_smap1)
            # then modify the file, leaving the old map untouched
            d.addCallback(lambda res: log.msg("starting winning write"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # now attempt to modify the file with the old servermap. This
            # will look just like an uncoordinated write, in which every
            # single share got updated between our mapupdate and our publish
            d.addCallback(lambda res: log.msg("starting doomed write"))
            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "test_publish_surprise", None,
                                          n.upload,
                                          "contents 2a", self.old_map))
            return d
        d.addCallback(_created)
        return d

    def test_retrieve_surprise(self):
        basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
            d.addCallback(_got_smap1)
            # then modify the file, leaving the old map untouched
            d.addCallback(lambda res: log.msg("starting winning write"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # now attempt to retrieve the old version with the old servermap.
            # This will look like someone has changed the file since we
            # updated the servermap.
            d.addCallback(lambda res: n._cache._clear())
            d.addCallback(lambda res: log.msg("starting doomed read"))
            d.addCallback(lambda res:
                          self.shouldFail(NotEnoughSharesError,
                                          "test_retrieve_surprise",
                                          "ran out of peers: have 0 shares (k=3)",
                                          n.download_version,
                                          self.old_map,
                                          self.old_map.best_recoverable_version(),
                                          ))
            return d
        d.addCallback(_created)
        return d

    def test_unexpected_shares(self):
        # upload the file, take a servermap, shut down one of the servers,
        # upload it again (causing shares to appear on a new server), then
        # upload using the old servermap. The last upload should fail with an
        # UncoordinatedWriteError, because of the shares that didn't appear
        # in the servermap.
        basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
                # now shut down one of the servers
                peer0 = list(smap.make_sharemap()[0])[0]
                self.client.debug_remove_connection(peer0)
                # then modify the file, leaving the old map untouched
                log.msg("starting winning write")
                return n.overwrite("contents 2")
            d.addCallback(_got_smap1)
            # now attempt to modify the file with the old servermap. This
            # will look just like an uncoordinated write, in which every
            # single share got updated between our mapupdate and our publish
            d.addCallback(lambda res: log.msg("starting doomed write"))
            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
1894                                           "test_surprise", None,
                                          n.upload,
                                          "contents 2a", self.old_map))
            return d
        d.addCallback(_created)
        return d

    def test_bad_server(self):
        # Break one server, then create the file: the initial publish should
        # complete with an alternate server. Breaking a second server should
        # not prevent an update from succeeding either.
        basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
        self.client = LessFakeClient(basedir, 20)
        # to make sure that one of the initial peers is broken, we have to
        # get creative. We create the keys, so we can figure out the storage
        # index, but we hold off on doing the initial publish until we've
        # broken the server on which the first share wants to be stored.
        n = FastMutableFileNode(self.client)
        d = defer.succeed(None)
        d.addCallback(n._generate_pubprivkeys, keysize=522)
        d.addCallback(n._generated)
        def _break_peer0(res):
            si = n.get_storage_index()
            peerlist = self.client.storage_broker.get_servers_for_index(si)
            peerid0, connection0 = peerlist[0]
            peerid1, connection1 = peerlist[1]
            connection0.broken = True
            self.connection1 = connection1
        d.addCallback(_break_peer0)
        # now let the initial publish finally happen
        d.addCallback(lambda res: n._upload("contents 1", None))
        # that ought to work
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        # now break the second peer
        def _break_peer1(res):
            self.connection1.broken = True
        d.addCallback(_break_peer1)
        d.addCallback(lambda res: n.overwrite("contents 2"))
        # that ought to work too
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        def _explain_error(f):
            print f
            if f.check(NotEnoughServersError):
                print "first_error:", f.value.first_error
            return f
        d.addErrback(_explain_error)
        return d

    def test_bad_server_overlap(self):
        # like test_bad_server, but with no extra unused servers to fall back
        # upon. This means that we must re-use a server which we've already
        # used. If we don't remember the fact that we sent them one share
        # already, we'll mistakenly think we're experiencing an
        # UncoordinatedWriteError.

        basedir = os.path.join("mutable/CollidingWrites/test_bad_server_overlap")
        self.client = LessFakeClient(basedir, 10)
        sb = self.client.get_storage_broker()

        peerids = list(sb.get_all_serverids())
        self.client.debug_break_connection(peerids[0])

        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            # now break one of the remaining servers
            def _break_second_server(res):
                self.client.debug_break_connection(peerids[1])
            d.addCallback(_break_second_server)
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # that ought to work too
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_publish_all_servers_bad(self):
        # Break all servers: the publish should fail
        basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
        self.client = LessFakeClient(basedir, 20)
        sb = self.client.get_storage_broker()
        for peerid in sb.get_all_serverids():
            self.client.debug_break_connection(peerid)
        d = self.shouldFail(NotEnoughServersError,
                            "test_publish_all_servers_bad",
                            "Ran out of non-bad servers",
                            self.client.create_mutable_file, "contents")
        return d

    def test_publish_no_servers(self):
        # no servers at all: the publish should fail
        basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
        self.client = LessFakeClient(basedir, 0)
        d = self.shouldFail(NotEnoughServersError,
                            "test_publish_no_servers",
                            "Ran out of non-bad servers",
                            self.client.create_mutable_file, "contents")
        return d
    test_publish_no_servers.timeout = 30


    def test_privkey_query_error(self):
        # when a servermap is updated with MODE_WRITE, it tries to get the
        # privkey. Something might go wrong during this query attempt.
        self.client = FakeClient(20)
        # we need some contents that are large enough to push the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d = self.client.create_mutable_file(LARGE)
        def _created(n):
            self.uri = n.get_uri()
            self.n2 = self.client.create_node_from_uri(self.uri)
            # we start by doing a map update to figure out which is the first
            # server.
            return n.get_servermap(MODE_WRITE)
        d.addCallback(_created)
        d.addCallback(lambda res: fireEventually(res))
        def _got_smap1(smap):
            peer0 = list(smap.make_sharemap()[0])[0]
            # we tell the fake storage to release this peer's response first,
            # so that this peer will be asked for the privkey first
            self.client._storage._sequence = [peer0]
            # now we make the peer fail their second query
            self.client._storage._special_answers[peer0] = ["normal", "fail"]
        d.addCallback(_got_smap1)
        # now we update a servermap from a new node (which doesn't have the
        # privkey yet, forcing it to use a separate privkey query). Each
        # query response will trigger a privkey query, and since we're using
        # _sequence to make the peer0 response come back first, we'll send it
        # a privkey query first, and _sequence will again ensure that the
        # peer0 query will also come back before the others, and then
        # _special_answers will make sure that the query raises an exception.
        # The whole point of these hijinks is to exercise the code in
        # _privkey_query_failed. Note that the map-update will succeed, since
        # we'll just get a copy from one of the other shares.
        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
        # Using FakeStorage._sequence means there will be read requests still
        # floating around; wait for them to retire
        def _cancel_timer(res):
            if self.client._storage._pending_timer:
                self.client._storage._pending_timer.cancel()
            return res
        d.addBoth(_cancel_timer)
        return d

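    # Hedged note on the FakeStorage knobs used above (both are test-only
    # attributes of the in-RAM storage defined near the top of this file):
    # _sequence is a list of peerids that controls the order in which
    # deferred read answers are released, and _special_answers maps a peerid
    # to a queue of per-query behaviors: "normal" answers normally, "fail"
    # errbacks with IntentionalError, and "none" answers as if the peer held
    # no shares.
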
    def test_privkey_query_missing(self):
        # like test_privkey_query_error, but the shares are deleted by the
        # second query, instead of raising an exception.
        self.client = FakeClient(20)
        LARGE = "These are Larger contents" * 200 # about 5KB
        d = self.client.create_mutable_file(LARGE)
        def _created(n):
            self.uri = n.get_uri()
            self.n2 = self.client.create_node_from_uri(self.uri)
            return n.get_servermap(MODE_WRITE)
        d.addCallback(_created)
        d.addCallback(lambda res: fireEventually(res))
        def _got_smap1(smap):
            peer0 = list(smap.make_sharemap()[0])[0]
            self.client._storage._sequence = [peer0]
            self.client._storage._special_answers[peer0] = ["normal", "none"]
        d.addCallback(_got_smap1)
        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
        def _cancel_timer(res):
            if self.client._storage._pending_timer:
                self.client._storage._pending_timer.cancel()
            return res
        d.addBoth(_cancel_timer)
        return d