import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, download, storage
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.encode import NotEnoughSharesError
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# FakeStorage's "fail" answer mode raises this; it was not defined anywhere
# in this module, so define it here.
class IntentionalError(Exception):
    pass

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

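    # Example (with hypothetical peerids) of how a test imposes an answer
    # order:
    #   s = FakeStorage()
    #   s._sequence = [peer_c, peer_a, peer_b]
    # read() will then queue each query's Deferred, and one second after the
    # first query arrives, _fire_readers answers peer_c, peer_a, and peer_b
    # in that order; peers not named in _sequence are answered last, in
    # arbitrary order.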
    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        extra = []
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
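            # A closer imitation of the real server (a hypothetical sketch,
            # not what this always-pass fake does) would answer read_vector
            # against the stored shares instead, e.g.:
            #   stored = self.storage._peers.get(self.peerid, {})
            #   readv = dict((sn, [data[o:o+l] for (o, l) in read_vector])
            #                for (sn, data) in stored.items())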
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
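        # Permuted-peer selection: each (key, peerid) pair hashes to a
        # stable sort key, so every storage index sees its own consistent
        # pseudo-random ordering of the same peers.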
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
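# e.g. flip_bit("ABC", 1) toggles the low bit of "B" (0x42 -> 0x43),
# returning "ACC"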

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
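
# usage, mirroring calls made by the tests below:
#   corrupt(None, storage, "signature")   # flip a bit in every signature
#   corrupt(None, storage, ("share_hash_chain", 4), shnums_to_corrupt=[9])
#     # flip a bit 4 bytes into share 9's share_hash_chain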

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
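            # (k=3 and "large size file" is 15 bytes, so 15 * 1000 = 15000
            # bytes of data -> 15000/3 = 5000 bytes per share, comfortably
            # over the 2000-byte readsize)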
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        def _toobig_modifier(old_contents):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1
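        # (a reading of the knobs, as an assumption about BackoffAgent's
        # internals: with factor=1 the 0.1s delay never grows, so the agent
        # exhausts its retry budget almost immediately and the
        # UncoordinatedWriteError surfaces)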

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
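                # (">BQ32s16s BBQQ" lays out: version byte, seqnum,
                # root_hash, IV/salt, then k, N, segsize, datalen -- the
                # same prefix fields that unpack_header parses)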
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10)) == 4
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class Servermap(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file("New contents go here")
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum)
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupted signature will be detected as invalid
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
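        # (with k=3, N=10 this corrupts the pubkey in shnums 0..6, leaving
        # shares 7, 8, and 9 intact)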
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class CheckerMixin:
    def check_good(self, r, where):
        self.failUnless(r.healthy, where)
        self.failIf(r.problems, where)
        return r

    def check_bad(self, r, where):
        self.failIf(r.healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        for (peerid, storage_index, shnum, f) in r.problems:
            if f.check(expected_exception):
                self.failUnless(substring in str(f),
                                "%s: substring '%s' not in '%s'" %
                                (where, substring, str(f)))
                return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))


class Checker(unittest.TestCase, CheckerMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d


    def test_check_good(self):
        d = self._fn.check()
        d.addCallback(self.check_good, "test_check_good")
        return d

    def test_check_no_shares(self):
        for shares in self._storage._peers.values():
            shares.clear()
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_no_shares")
        return d

    def test_check_not_enough_shares(self):
        for shares in self._storage._peers.values():
            for shnum in shares.keys():
                if shnum > 0:
                    del shares[shnum]
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_not_enough_shares")
        return d

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_all_bad_sig")
        return d

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this.. it doesn't look at actual data
        d = self._fn.check()
        d.addCallback(self.check_good, "test_check_all_bad_blocks")
        return d

    def test_verify_good(self):
        d = self._fn.check(verify=True)
        d.addCallback(self.check_good, "test_verify_good")
        return d

    def test_verify_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(verify=True)
        d.addCallback(self.check_bad, "test_verify_all_bad_sig")
        return d

    def test_verify_one_bad_sig(self):
        corrupt(None, self._storage, 1, [9]) # bad sig
        d = self._fn.check(verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sig")
        return d

    def test_verify_one_bad_block(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Verifier *will* notice this, since it examines every byte
        d = self._fn.check(verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_block")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "block hash tree failure",
                      "test_verify_one_bad_block")
        return d

    def test_verify_one_bad_sharehash(self):
        corrupt(None, self._storage, "share_hash_chain", [9], 5)
        d = self._fn.check(verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "corrupt hashes",
                      "test_verify_one_bad_sharehash")
        return d

    def test_verify_one_bad_encprivkey(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        d = self._fn.check(verify=True)
        d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
        d.addCallback(self.check_expected_failure,
                      CorruptShareError, "invalid privkey",
                      "test_verify_one_bad_encprivkey")
        return d

    def test_verify_one_bad_encprivkey_uncheckable(self):
        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
        readonly_fn = self._fn.get_readonly()
        # a read-only node has no way to validate the privkey
        d = readonly_fn.check(verify=True)
        d.addCallback(self.check_good,
                      "test_verify_one_bad_encprivkey_uncheckable")
        return d


class MultipleEncodings(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d
1286
1287     def make_servermap(self, mode=MODE_READ, oldmap=None):
1288         if oldmap is None:
1289             oldmap = ServerMap()
1290         smu = ServermapUpdater(self._fn, oldmap, mode)
1291         d = smu.update()
1292         return d
1293
1294     def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.
1298
1299         contents1 = "Contents for encoding 1 (3-of-10) go here"
1300         contents2 = "Contents for encoding 2 (4-of-9) go here"
1301         contents3 = "Contents for encoding 3 (4-of-7) go here"
1302
        # we make a file node that doesn't know what encoding parameters
        # were used, and which must therefore learn them from the shares
1305         fn3 = FastMutableFileNode(self._client)
1306         fn3.init_from_uri(self._fn.get_uri())
1307
        # now we publish each encoding in turn, and grab its shares
1309         d = self._encode(3, 10, contents1)
1310         def _encoded_1(shares):
1311             self._shares1 = shares
1312         d.addCallback(_encoded_1)
1313         d.addCallback(lambda res: self._encode(4, 9, contents2))
1314         def _encoded_2(shares):
1315             self._shares2 = shares
1316         d.addCallback(_encoded_2)
1317         d.addCallback(lambda res: self._encode(4, 7, contents3))
1318         def _encoded_3(shares):
1319             self._shares3 = shares
1320         d.addCallback(_encoded_3)
1321
1322         def _merge(res):
1323             log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but choosing which set supplies each
            # share according to the following sequence:
1327             #
1328             #  4-of-9  a  s2
1329             #  4-of-9  b  s2
1330             #  4-of-7  c   s3
1331             #  4-of-9  d  s2
            #  3-of-10 e s1
            #  3-of-10 f s1
            #  3-of-10 g s1
1335             #  4-of-9  h  s2
1336             #
            # so that no version is recoverable until fetch [g], at which
            # point version-s1 (the 3-of-10 form) becomes recoverable (its
            # three shares arrive at fetches [e], [f], and [g]). If the
            # implementation latches on to the first version it sees (s2),
            # it cannot finish until fetch [h] delivers s2's fourth share.
1341
1342             # Later, when we implement code that handles multiple versions,
1343             # we can use this framework to assert that all recoverable
1344             # versions are retrieved, and test that 'epsilon' does its job
1345
1346             places = [2, 2, 3, 2, 1, 1, 1, 2]
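            # a quick self-check of the schedule above (an added sketch; the
            # test does not depend on it): counting in shnum order, a set
            # becomes recoverable at the fetch where its share count first
            # reaches its k, so s1 completes at shnum 6 ([g]) and s2 at
            # shnum 7 ([h]), while s3 never gets its fourth share
            k_for = {1: 3, 2: 4, 3: 4}
            seen = dict([(which, 0) for which in k_for])
            recoverable_at = {}
            for shnum, which in enumerate(places):
                seen[which] += 1
                if seen[which] == k_for[which]:
                    recoverable_at[which] = shnum
            assert recoverable_at == {1: 6, 2: 7}, recoverable_at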
1347
1348             sharemap = {}
1349
            for peerid in self._client._peerids:
                peerid_s = shortnodeid_b2a(peerid)
                # create this peer's (empty) share dict once, outside the
                # shnum loop, so that a peer holding several shares keeps
                # all of them
                self._client._storage._peers[peerid] = peers = {}
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
1358                     in_1 = shnum in self._shares1[peerid]
1359                     in_2 = shnum in self._shares2.get(peerid, {})
1360                     in_3 = shnum in self._shares3.get(peerid, {})
1361                     #print peerid_s, shnum, which, in_1, in_2, in_3
1362                     if which == 1:
1363                         if in_1:
1364                             peers[shnum] = self._shares1[peerid][shnum]
1365                             sharemap[shnum] = peerid
1366                     elif which == 2:
1367                         if in_2:
1368                             peers[shnum] = self._shares2[peerid][shnum]
1369                             sharemap[shnum] = peerid
1370                     elif which == 3:
1371                         if in_3:
1372                             peers[shnum] = self._shares3[peerid][shnum]
1373                             sharemap[shnum] = peerid
1374
1375             # we don't bother placing any other shares
            # now order the response sequence so that shares come back in
            # shnum order (share 0 first)
1377             new_sequence = [sharemap[shnum]
1378                             for shnum in sorted(sharemap.keys())]
1379             self._client._storage._sequence = new_sequence
1380             log.msg("merge done")
1381         d.addCallback(_merge)
1382         d.addCallback(lambda res: fn3.download_best_version())
1383         def _retrieved(new_contents):
1384             # the current specified behavior is "first version recoverable"
1385             self.failUnlessEqual(new_contents, contents1)
1386         d.addCallback(_retrieved)
1387         return d
1388
1389 class MultipleVersions(unittest.TestCase, CheckerMixin):
1390     def setUp(self):
1391         self.CONTENTS = ["Contents 0",
1392                          "Contents 1",
1393                          "Contents 2",
1394                          "Contents 3a",
1395                          "Contents 3b"]
1396         self._copied_shares = {}
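        # copy index i in _copied_shares holds the shares as they stood just
        # after CONTENTS[i] was published; the seqnums run one ahead (copy 0
        # is seqnum 1, copy 2 is seqnum 3, and copies 3 and 4 are the two
        # parallel seqnum-4 versions, s4a and s4b)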
1397         num_peers = 20
1398         self._client = FakeClient(num_peers)
1399         self._storage = self._client._storage
1400         d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
1401         def _created(node):
1402             self._fn = node
1403             # now create multiple versions of the same file, and accumulate
1404             # their shares, so we can mix and match them later.
1405             d = defer.succeed(None)
1406             d.addCallback(self._copy_shares, 0)
1407             d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
1408             d.addCallback(self._copy_shares, 1)
1409             d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
1410             d.addCallback(self._copy_shares, 2)
1411             d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
1412             d.addCallback(self._copy_shares, 3)
1413             # now we replace all the shares with version s3, and upload a new
1414             # version to get s4b.
1415             rollback = dict([(i,2) for i in range(10)])
1416             d.addCallback(lambda res: self._set_versions(rollback))
1417             d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
1418             d.addCallback(self._copy_shares, 4)
1419             # we leave the storage in state 4
1420             return d
1421         d.addCallback(_created)
1422         return d
1423
1424     def _copy_shares(self, ignored, index):
1425         shares = self._client._storage._peers
        # we need a copy that is deep enough: two levels, since the leaf
        # share strings are immutable
1427         new_shares = {}
1428         for peerid in shares:
1429             new_shares[peerid] = {}
1430             for shnum in shares[peerid]:
1431                 new_shares[peerid][shnum] = shares[peerid][shnum]
1432         self._copied_shares[index] = new_shares
1433
1434     def _set_versions(self, versionmap):
        # versionmap maps shnums to the copy index (0,1,2,3,4) we want the
        # share to be at. Any shnum left out of the map stays at its current
        # version.
1438         shares = self._client._storage._peers
1439         oldshares = self._copied_shares
1440         for peerid in shares:
1441             for shnum in shares[peerid]:
1442                 if shnum in versionmap:
1443                     index = versionmap[shnum]
1444                     shares[peerid][shnum] = oldshares[index][peerid][shnum]
1445
1446     def test_multiple_versions(self):
1447         # if we see a mix of versions in the grid, download_best_version
1448         # should get the latest one
1449         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1450         d = self._fn.download_best_version()
1451         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1452         # and the checker should report problems
1453         d.addCallback(lambda res: self._fn.check())
1454         d.addCallback(self.check_bad, "test_multiple_versions")
1455
        # but if everything is at copy 2, that's what we should download
1457         d.addCallback(lambda res:
1458                       self._set_versions(dict([(i,2) for i in range(10)])))
1459         d.addCallback(lambda res: self._fn.download_best_version())
1460         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # if exactly one share is at copy 3 (seqnum 4), we should still get v2
1462         d.addCallback(lambda res:
1463                       self._set_versions({0:3}))
1464         d.addCallback(lambda res: self._fn.download_best_version())
1465         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1466         # but the servermap should see the unrecoverable version. This
1467         # depends upon the single newer share being queried early.
1468         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1469         def _check_smap(smap):
1470             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1471             newer = smap.unrecoverable_newer_versions()
1472             self.failUnlessEqual(len(newer), 1)
1473             verinfo, health = newer.items()[0]
1474             self.failUnlessEqual(verinfo[0], 4)
1475             self.failUnlessEqual(health, (1,3))
1476             self.failIf(smap.needs_merge())
1477         d.addCallback(_check_smap)
1478         # if we have a mix of two parallel versions (s4a and s4b), we could
1479         # recover either
1480         d.addCallback(lambda res:
1481                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1482                                           1:4,3:4,5:4,7:4,9:4}))
1483         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1484         def _check_smap_mixed(smap):
1485             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1486             newer = smap.unrecoverable_newer_versions()
1487             self.failUnlessEqual(len(newer), 0)
1488             self.failUnless(smap.needs_merge())
1489         d.addCallback(_check_smap_mixed)
1490         d.addCallback(lambda res: self._fn.download_best_version())
1491         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1492                                                   res == self.CONTENTS[4]))
1493         return d
1494
1495     def test_replace(self):
1496         # if we see a mix of versions in the grid, we should be able to
1497         # replace them all with a newer version
1498
        # if exactly one share is at copy 3 (seqnum 4), modify() should
        # download and build on v2 (the only recoverable version), and
        # publish the replacement with a seqnum above the outlier's. Note
        # that the index we give to _set_versions is different from the
        # sequence number.
1502         target = dict([(i,2) for i in range(10)]) # seqnum3
1503         target[0] = 3 # seqnum4
1504         self._set_versions(target)
1505
1506         def _modify(oldversion):
1507             return oldversion + " modified"
1508         d = self._fn.modify(_modify)
1509         d.addCallback(lambda res: self._fn.download_best_version())
1510         expected = self.CONTENTS[2] + " modified"
1511         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1512         # and the servermap should indicate that the outlier was replaced too
1513         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1514         def _check_smap(smap):
1515             self.failUnlessEqual(smap.highest_seqnum(), 5)
1516             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1517             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1518         d.addCallback(_check_smap)
1519         return d
1520
1521
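# For orientation, a minimal sketch of the DictOfSets semantics that
# test_dict_of_sets below pins down (the real class lives in
# allmydata.mutable.common; this sketch is not used by the tests):
#
#   class DictOfSets(dict):
#       def add(self, key, value):
#           self.setdefault(key, set()).add(value)
#       def discard(self, key, value):
#           if key in self:
#               self[key].discard(value)
#               if not self[key]:
#                   del self[key]
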
1522 class Utils(unittest.TestCase):
1523     def test_dict_of_sets(self):
1524         ds = DictOfSets()
1525         ds.add(1, "a")
1526         ds.add(2, "b")
1527         ds.add(2, "b")
1528         ds.add(2, "c")
1529         self.failUnlessEqual(ds[1], set(["a"]))
1530         self.failUnlessEqual(ds[2], set(["b", "c"]))
1531         ds.discard(3, "d") # should not raise an exception
1532         ds.discard(2, "b")
1533         self.failUnlessEqual(ds[2], set(["c"]))
1534         ds.discard(2, "c")
1535         self.failIf(2 in ds)
1536
1537     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1538         # we compare this against sets of integers
1539         x = set(range(x_start, x_start+x_length))
1540         y = set(range(y_start, y_start+y_length))
1541         should_be_inside = x.issubset(y)
1542         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1543                                                          y_start, y_length),
1544                              str((x_start, x_length, y_start, y_length)))
1545
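    # for intuition: with a non-empty x range, the set-based subset check
    # above reduces to a closed-form interval test. A hedged sketch of the
    # predicate these cases pin down (not the ResponseCache code itself):
    #
    #   def inside(x_start, x_length, y_start, y_length):
    #       if x_length == 0:
    #           return True  # an empty range is a subset of anything
    #       return (y_start <= x_start and
    #               x_start + x_length <= y_start + y_length)
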
1546     def test_cache_inside(self):
1547         c = ResponseCache()
1548         x_start = 10
1549         x_length = 5
1550         for y_start in range(8, 17):
1551             for y_length in range(8):
1552                 self._do_inside(c, x_start, x_length, y_start, y_length)
1553
1554     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1555         # we compare this against sets of integers
1556         x = set(range(x_start, x_start+x_length))
1557         y = set(range(y_start, y_start+y_length))
1558         overlap = bool(x.intersection(y))
1559         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1560                                                       y_start, y_length),
1561                              str((x_start, x_length, y_start, y_length)))
1562
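    # similarly, the intersection check above reduces to: two ranges
    # overlap iff both are non-empty and each starts before the other
    # ends. A hedged sketch (again, not the ResponseCache code itself):
    #
    #   def does_overlap(x_start, x_length, y_start, y_length):
    #       return (x_length > 0 and y_length > 0 and
    #               x_start < y_start + y_length and
    #               y_start < x_start + x_length)
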
1563     def test_cache_overlap(self):
1564         c = ResponseCache()
1565         x_start = 10
1566         x_length = 5
1567         for y_start in range(8, 17):
1568             for y_length in range(8):
1569                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1570
1571     def test_cache(self):
1572         c = ResponseCache()
1573         # xdata = base62.b2a(os.urandom(100))[:100]
1574         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1575         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1576         nope = (None, None)
1577         c.add("v1", 1, 0, xdata, "time0")
1578         c.add("v1", 1, 2000, ydata, "time1")
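        # the contract these assertions pin down: a read hits only when the
        # verinfo and shnum match and the requested [offset, offset+length)
        # span falls entirely within a single cached write; anything that
        # strays outside, or would need to span two cache entries, misses
        # and returns (None, None)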
1579         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1580         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1581         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1582         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1583         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1584         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1585         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1586         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1587         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1588         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1589         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1590         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1591         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1592         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1593         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1594         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1595         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1596         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1597
        # optional behavior: joining adjacent fragments into one read is not
        # implemented, so the assertion below stays commented out
1599         c = ResponseCache()
1600         c.add("v1", 1, 0, xdata[:10], "time0")
1601         c.add("v1", 1, 10, xdata[10:20], "time1")
1602         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1603
1604 class Exceptions(unittest.TestCase):
1605     def test_repr(self):
1606         nmde = NeedMoreDataError(100, 50, 100)
1607         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1608         ucwe = UncoordinatedWriteError()
1609         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1610
1611 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1612 # instances which always succeed. So we need a less-fake one.
1613
1614 class IntentionalError(Exception):
1615     pass
1616
1617 class LocalWrapper:
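    # wraps a storage server so that callRemote("foo", ...) invokes
    # remote_foo() in a later reactor turn (via fireEventually), letting
    # tests set .broken to make every subsequent call fail with
    # IntentionalError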
1618     def __init__(self, original):
1619         self.original = original
1620         self.broken = False
1621         self.post_call_notifier = None
1622     def callRemote(self, methname, *args, **kwargs):
1623         def _call():
1624             if self.broken:
1625                 raise IntentionalError("I was asked to break")
1626             meth = getattr(self.original, "remote_" + methname)
1627             return meth(*args, **kwargs)
1628         d = fireEventually()
1629         d.addCallback(lambda res: _call())
1630         if self.post_call_notifier:
1631             d.addCallback(self.post_call_notifier, methname)
1632         return d
1633
1634 class LessFakeClient(FakeClient):
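    # like FakeClient, but each peer is a real StorageServer rooted under
    # basedir, wrapped in a LocalWrapper so calls stay in-process and can
    # be broken on demand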
1635
1636     def __init__(self, basedir, num_peers=10):
1637         self._num_peers = num_peers
1638         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1639                          for i in range(self._num_peers)]
1640         self._connections = {}
1641         for peerid in self._peerids:
1642             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1643             make_dirs(peerdir)
1644             ss = storage.StorageServer(peerdir)
1645             ss.setNodeID(peerid)
1646             lw = LocalWrapper(ss)
1647             self._connections[peerid] = lw
1648         self.nodeid = "fakenodeid"
1649
1650
1651 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1652     def test_publish_surprise(self):
        basedir = os.path.join("mutable", "CollidingWrites", "test_surprise")
1654         self.client = LessFakeClient(basedir)
1655         d = self.client.create_mutable_file("contents 1")
1656         def _created(n):
1657             d = defer.succeed(None)
1658             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1659             def _got_smap1(smap):
1660                 # stash the old state of the file
1661                 self.old_map = smap
1662             d.addCallback(_got_smap1)
1663             # then modify the file, leaving the old map untouched
1664             d.addCallback(lambda res: log.msg("starting winning write"))
1665             d.addCallback(lambda res: n.overwrite("contents 2"))
1666             # now attempt to modify the file with the old servermap. This
1667             # will look just like an uncoordinated write, in which every
1668             # single share got updated between our mapupdate and our publish
1669             d.addCallback(lambda res: log.msg("starting doomed write"))
1670             d.addCallback(lambda res:
1671                           self.shouldFail(UncoordinatedWriteError,
1672                                           "test_publish_surprise", None,
1673                                           n.upload,
1674                                           "contents 2a", self.old_map))
1675             return d
1676         d.addCallback(_created)
1677         return d
1678
1679     def test_retrieve_surprise(self):
        basedir = os.path.join("mutable", "CollidingWrites", "test_retrieve")
1681         self.client = LessFakeClient(basedir)
1682         d = self.client.create_mutable_file("contents 1")
1683         def _created(n):
1684             d = defer.succeed(None)
1685             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1686             def _got_smap1(smap):
1687                 # stash the old state of the file
1688                 self.old_map = smap
1689             d.addCallback(_got_smap1)
1690             # then modify the file, leaving the old map untouched
1691             d.addCallback(lambda res: log.msg("starting winning write"))
1692             d.addCallback(lambda res: n.overwrite("contents 2"))
1693             # now attempt to retrieve the old version with the old servermap.
1694             # This will look like someone has changed the file since we
1695             # updated the servermap.
1696             d.addCallback(lambda res: n._cache._clear())
1697             d.addCallback(lambda res: log.msg("starting doomed read"))
1698             d.addCallback(lambda res:
1699                           self.shouldFail(NotEnoughSharesError,
1700                                           "test_retrieve_surprise",
1701                                           "ran out of peers: have 0 shares (k=3)",
1702                                           n.download_version,
1703                                           self.old_map,
1704                                           self.old_map.best_recoverable_version(),
1705                                           ))
1706             return d
1707         d.addCallback(_created)
1708         return d
1709
1710     def test_unexpected_shares(self):
1711         # upload the file, take a servermap, shut down one of the servers,
1712         # upload it again (causing shares to appear on a new server), then
1713         # upload using the old servermap. The last upload should fail with an
1714         # UncoordinatedWriteError, because of the shares that didn't appear
1715         # in the servermap.
        basedir = os.path.join("mutable", "CollidingWrites",
                               "test_unexpected_shares")
1717         self.client = LessFakeClient(basedir)
1718         d = self.client.create_mutable_file("contents 1")
1719         def _created(n):
1720             d = defer.succeed(None)
1721             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1722             def _got_smap1(smap):
1723                 # stash the old state of the file
1724                 self.old_map = smap
1725                 # now shut down one of the servers
1726                 peer0 = list(smap.make_sharemap()[0])[0]
1727                 self.client._connections.pop(peer0)
1728                 # then modify the file, leaving the old map untouched
1729                 log.msg("starting winning write")
1730                 return n.overwrite("contents 2")
1731             d.addCallback(_got_smap1)
1732             # now attempt to modify the file with the old servermap. This
1733             # will look just like an uncoordinated write, in which every
1734             # single share got updated between our mapupdate and our publish
1735             d.addCallback(lambda res: log.msg("starting doomed write"))
1736             d.addCallback(lambda res:
1737                           self.shouldFail(UncoordinatedWriteError,
1738                                           "test_surprise", None,
1739                                           n.upload,
1740                                           "contents 2a", self.old_map))
1741             return d
1742         d.addCallback(_created)
1743         return d
1744
1745     def test_bad_server(self):
1746         # Break one server, then create the file: the initial publish should
1747         # complete with an alternate server. Breaking a second server should
1748         # not prevent an update from succeeding either.
        basedir = os.path.join("mutable", "CollidingWrites", "test_bad_server")
1750         self.client = LessFakeClient(basedir, 20)
1751         # to make sure that one of the initial peers is broken, we have to
1752         # get creative. We create the keys, so we can figure out the storage
1753         # index, but we hold off on doing the initial publish until we've
1754         # broken the server on which the first share wants to be stored.
1755         n = FastMutableFileNode(self.client)
1756         d = defer.succeed(None)
1757         d.addCallback(n._generate_pubprivkeys)
1758         d.addCallback(n._generated)
1759         def _break_peer0(res):
1760             si = n.get_storage_index()
1761             peerlist = self.client.get_permuted_peers("storage", si)
1762             peerid0, connection0 = peerlist[0]
1763             peerid1, connection1 = peerlist[1]
1764             connection0.broken = True
1765             self.connection1 = connection1
1766         d.addCallback(_break_peer0)
1767         # now let the initial publish finally happen
1768         d.addCallback(lambda res: n._upload("contents 1", None))
1769         # that ought to work
1770         d.addCallback(lambda res: n.download_best_version())
1771         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1772         # now break the second peer
1773         def _break_peer1(res):
1774             self.connection1.broken = True
1775         d.addCallback(_break_peer1)
1776         d.addCallback(lambda res: n.overwrite("contents 2"))
1777         # that ought to work too
1778         d.addCallback(lambda res: n.download_best_version())
1779         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1780         return d
1781
1782     def test_publish_all_servers_bad(self):
1783         # Break all servers: the publish should fail
        basedir = os.path.join("mutable", "CollidingWrites",
                               "publish_all_servers_bad")
1785         self.client = LessFakeClient(basedir, 20)
1786         for connection in self.client._connections.values():
1787             connection.broken = True
1788         d = self.shouldFail(NotEnoughServersError,
1789                             "test_publish_all_servers_bad",
1790                             "Ran out of non-bad servers",
1791                             self.client.create_mutable_file, "contents")
1792         return d
1793
1794     def test_publish_no_servers(self):
1795         # no servers at all: the publish should fail
        basedir = os.path.join("mutable", "CollidingWrites",
                               "publish_no_servers")
1797         self.client = LessFakeClient(basedir, 0)
1798         d = self.shouldFail(NotEnoughServersError,
1799                             "test_publish_no_servers",
1800                             "Ran out of non-bad servers",
1801                             self.client.create_mutable_file, "contents")
1802         return d
1803     test_publish_no_servers.timeout = 30
1804
1805
1806     def test_privkey_query_error(self):
1807         # when a servermap is updated with MODE_WRITE, it tries to get the
1808         # privkey. Something might go wrong during this query attempt.
1809         self.client = FakeClient(20)
1810         # we need some contents that are large enough to push the privkey out
1811         # of the early part of the file
1812         LARGE = "These are Larger contents" * 200 # about 5KB
1813         d = self.client.create_mutable_file(LARGE)
1814         def _created(n):
1815             self.uri = n.get_uri()
1816             self.n2 = self.client.create_node_from_uri(self.uri)
1817             # we start by doing a map update to figure out which is the first
1818             # server.
1819             return n.get_servermap(MODE_WRITE)
1820         d.addCallback(_created)
1821         d.addCallback(lambda res: fireEventually(res))
1822         def _got_smap1(smap):
1823             peer0 = list(smap.make_sharemap()[0])[0]
            # we arrange for peer0's response to be released first, so that
            # peer0 is the one asked for the privkey first
1826             self.client._storage._sequence = [peer0]
            # now we make that peer fail its second query
1828             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1829         d.addCallback(_got_smap1)
1830         # now we update a servermap from a new node (which doesn't have the
1831         # privkey yet, forcing it to use a separate privkey query). Each
1832         # query response will trigger a privkey query, and since we're using
1833         # _sequence to make the peer0 response come back first, we'll send it
1834         # a privkey query first, and _sequence will again ensure that the
1835         # peer0 query will also come back before the others, and then
1836         # _special_answers will make sure that the query raises an exception.
1837         # The whole point of these hijinks is to exercise the code in
1838         # _privkey_query_failed. Note that the map-update will succeed, since
1839         # we'll just get a copy from one of the other shares.
1840         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1841         # Using FakeStorage._sequence means there will be read requests still
        # floating around; wait for them to retire
1843         def _cancel_timer(res):
1844             if self.client._storage._pending_timer:
1845                 self.client._storage._pending_timer.cancel()
1846             return res
1847         d.addBoth(_cancel_timer)
1848         return d
1849
1850     def test_privkey_query_missing(self):
1851         # like test_privkey_query_error, but the shares are deleted by the
1852         # second query, instead of raising an exception.
1853         self.client = FakeClient(20)
1854         LARGE = "These are Larger contents" * 200 # about 5KB
1855         d = self.client.create_mutable_file(LARGE)
1856         def _created(n):
1857             self.uri = n.get_uri()
1858             self.n2 = self.client.create_node_from_uri(self.uri)
1859             return n.get_servermap(MODE_WRITE)
1860         d.addCallback(_created)
1861         d.addCallback(lambda res: fireEventually(res))
1862         def _got_smap1(smap):
1863             peer0 = list(smap.make_sharemap()[0])[0]
1864             self.client._storage._sequence = [peer0]
1865             self.client._storage._special_answers[peer0] = ["normal", "none"]
1866         d.addCallback(_got_smap1)
1867         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1868         def _cancel_timer(res):
1869             if self.client._storage._pending_timer:
1870                 self.client._storage._pending_timer.cancel()
1871             return res
1872         d.addBoth(_cancel_timer)
1873         return d