
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.immutable.encode import NotEnoughSharesError
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522
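    # 522 bits is, as far as we know, the smallest RSA key size that the
    # underlying pycryptopp library will generate, so this is about as fast
    # as test key generation can get without switching algorithms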

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class IntentionalError(Exception):
    # raised on purpose by FakeStorage's "fail" special-answer mode below
    pass

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}
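
    # A rough usage sketch (the peerid values here are hypothetical):
    #   s = FakeStorage()
    #   s._sequence = list_of_peerids   # answer read queries in this order
    #   s._special_answers[peerid] = ["fail", "none"]
    # with that setup, the first read from 'peerid' errors out, the second
    # returns no shares, and any later reads behave normally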
    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        extra = []
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
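        # seeking past the end of a StringIO and then writing appears to pad
        # the gap with NUL bytes, so shares can be written in any offset
        # order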
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
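
    # callRemote() stands in for a foolscap remote method call:
    # fireEventually() defers the dispatch to a later reactor turn, so
    # answers arrive asynchronously, roughly like a real network round trip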
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
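        # like Tahoe's real peer-selection code, this orders peers by hashing
        # the key together with each peerid, yielding a stable per-file
        # permutation of the server list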
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
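# for example, flip_bit("\x00\xff", 0) == "\x01\xff": only the low bit of the
# byte at byte_offset is inverted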

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
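# examples: corrupt(None, s, "signature") flips a bit in every share's
# signature, while corrupt(None, s, "share_data", [0, 1]) touches only shares
# 0 and 1. 'offset' may name a field from the share's offset table (as the
# tests below do) or be a raw integer byte offset.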

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
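        # verinfo is a (seqnum, root_hash, IV, segsize, datalength, k, N,
        # prefix, offsets_tuple) tuple, so verinfo[0] is the version's seqnum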
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        def _toobig_modifier(old_contents):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # log2(10), rounded up
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on 10
    # when we publish to 3 peers, we should get either 3 or 4 shares per peer
    # when we publish to zero peers, we should get a NotEnoughSharesError

class Servermap(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file("New contents go here")
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class CheckerMixin:
    def check_good(self, r, where):
        self.failUnless(r.healthy, where)
        self.failIf(r.problems, where)
        return r

    def check_bad(self, r, where):
        self.failIf(r.healthy, where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        for (peerid, storage_index, shnum, f) in r.problems:
            if f.check(expected_exception):
                self.failUnless(substring in str(f),
                                "%s: substring '%s' not in '%s'" %
                                (where, substring, str(f)))
                return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))


class Checker(unittest.TestCase, CheckerMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d


    def test_check_good(self):
        d = self._fn.check()
        d.addCallback(self.check_good, "test_check_good")
        return d

    def test_check_no_shares(self):
        for shares in self._storage._peers.values():
            shares.clear()
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_no_shares")
        return d

    def test_check_not_enough_shares(self):
        for shares in self._storage._peers.values():
            for shnum in shares.keys():
                if shnum > 0:
                    del shares[shnum]
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_not_enough_shares")
        return d

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check()
        d.addCallback(self.check_bad, "test_check_all_bad_sig")
        return d

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
        # the Checker won't notice this... it doesn't look at actual data
1189         d = self._fn.check()
1190         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1191         return d
1192
1193     def test_verify_good(self):
1194         d = self._fn.check(verify=True)
1195         d.addCallback(self.check_good, "test_verify_good")
1196         return d
1197
1198     def test_verify_all_bad_sig(self):
1199         corrupt(None, self._storage, 1) # bad sig
1200         d = self._fn.check(verify=True)
1201         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1202         return d
1203
1204     def test_verify_one_bad_sig(self):
1205         corrupt(None, self._storage, 1, [9]) # bad sig
1206         d = self._fn.check(verify=True)
1207         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1208         return d
1209
1210     def test_verify_one_bad_block(self):
1211         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1212         # the Verifier *will* notice this, since it examines every byte
1213         d = self._fn.check(verify=True)
1214         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1215         d.addCallback(self.check_expected_failure,
1216                       CorruptShareError, "block hash tree failure",
1217                       "test_verify_one_bad_block")
1218         return d
1219
1220     def test_verify_one_bad_sharehash(self):
1221         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1222         d = self._fn.check(verify=True)
1223         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1224         d.addCallback(self.check_expected_failure,
1225                       CorruptShareError, "corrupt hashes",
1226                       "test_verify_one_bad_sharehash")
1227         return d
1228
1229     def test_verify_one_bad_encprivkey(self):
1230         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1231         d = self._fn.check(verify=True)
1232         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1233         d.addCallback(self.check_expected_failure,
1234                       CorruptShareError, "invalid privkey",
1235                       "test_verify_one_bad_encprivkey")
1236         return d
1237
1238     def test_verify_one_bad_encprivkey_uncheckable(self):
1239         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1240         readonly_fn = self._fn.get_readonly()
1241         # a read-only node has no way to validate the privkey
1242         d = readonly_fn.check(verify=True)
1243         d.addCallback(self.check_good,
1244                       "test_verify_one_bad_encprivkey_uncheckable")
1245         return d
1246
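# As the calls above suggest, the corrupt() helper (defined earlier in this
# file) takes either a raw byte offset or a named field ("share_data",
# "share_hash_chain", "enc_privkey", ...) as its third argument, plus an
# optional list of shnums to damage and an optional offset within the named
# field; e.g. corrupt(None, self._storage, "share_hash_chain", [9], 5)
# damages byte 5 of share 9's share_hash_chain.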
1247
1248 class MultipleEncodings(unittest.TestCase):
1249     def setUp(self):
1250         self.CONTENTS = "New contents go here"
1251         num_peers = 20
1252         self._client = FakeClient(num_peers)
1253         self._storage = self._client._storage
1254         d = self._client.create_mutable_file(self.CONTENTS)
1255         def _created(node):
1256             self._fn = node
1257         d.addCallback(_created)
1258         return d
1259
1260     def _encode(self, k, n, data):
1261         # encode 'data' into a peerid->shares dict.
1262
1263         fn2 = FastMutableFileNode(self._client)
1264         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1265         # and _fingerprint
1266         fn = self._fn
1267         fn2.init_from_uri(fn.get_uri())
1268         # then we copy over other fields that are normally fetched from the
1269         # existing shares
1270         fn2._pubkey = fn._pubkey
1271         fn2._privkey = fn._privkey
1272         fn2._encprivkey = fn._encprivkey
1273         # and set the encoding parameters to something completely different
1274         fn2._required_shares = k
1275         fn2._total_shares = n
1276
1277         s = self._client._storage
1278         s._peers = {} # clear existing storage
1279         p2 = Publish(fn2, None)
1280         d = p2.publish(data)
1281         def _published(res):
1282             shares = s._peers
1283             s._peers = {}
1284             return shares
1285         d.addCallback(_published)
1286         return d
1287
1288     def make_servermap(self, mode=MODE_READ, oldmap=None):
1289         if oldmap is None:
1290             oldmap = ServerMap()
1291         smu = ServermapUpdater(self._fn, oldmap, mode)
1292         d = smu.update()
1293         return d
1294
1295     def test_multiple_encodings(self):
1296         # we encode the same file in three different ways (3-of-10, 4-of-9,
1297         # and 4-of-7), then mix up the shares, to make sure that download
1298         # survives seeing a variety of encodings. This is tricky to set up.
1299
1300         contents1 = "Contents for encoding 1 (3-of-10) go here"
1301         contents2 = "Contents for encoding 2 (4-of-9) go here"
1302         contents3 = "Contents for encoding 3 (4-of-7) go here"
1303
1304         # we make a file node that doesn't know what encoding
1305         # parameters to use; it has to learn them from the shares it sees
1306         fn3 = FastMutableFileNode(self._client)
1307         fn3.init_from_uri(self._fn.get_uri())
1308
1309         # now we publish the file under each encoding in turn, grabbing the shares
1310         d = self._encode(3, 10, contents1)
1311         def _encoded_1(shares):
1312             self._shares1 = shares
1313         d.addCallback(_encoded_1)
1314         d.addCallback(lambda res: self._encode(4, 9, contents2))
1315         def _encoded_2(shares):
1316             self._shares2 = shares
1317         d.addCallback(_encoded_2)
1318         d.addCallback(lambda res: self._encode(4, 7, contents3))
1319         def _encoded_3(shares):
1320             self._shares3 = shares
1321         d.addCallback(_encoded_3)
1322
1323         def _merge(res):
1324             log.msg("merging sharelists")
1325             # we merge the shares from the three sets, leaving each shnum
1326             # in its original location, but using a share from set1, set2,
1327             # or set3 according to the following sequence:
1328             #
1329             #  4-of-9   a  s2
1330             #  4-of-9   b  s2
1331             #  4-of-7   c   s3
1332             #  4-of-9   d  s2
1333             #  3-of-10  e s1
1334             #  3-of-10  f s1
1335             #  3-of-10  g s1
1336             #  4-of-9   h  s2
1337             #
1338             # so that no version is recoverable until fetch [g], at which
1339             # point version-s1 (the 3-of-10 form) becomes recoverable; s2
1340             # (the 4-of-9 form) gets its fourth share only at fetch [h],
1341             # and s3 (4-of-7) never reaches its k at all.
1342
1343             # Later, when we implement code that handles multiple versions,
1344             # we can use this framework to assert that all recoverable
1345             # versions are retrieved, and test that 'epsilon' does its job
1346
1347             places = [2, 2, 3, 2, 1, 1, 1, 2]
1348
1349             sharemap = {}
1350
1351             for i,peerid in enumerate(self._client._peerids):
1352                 peerid_s = shortnodeid_b2a(peerid)
1353                 self._client._storage._peers[peerid] = peers = {} # reset once per peer
1354                 for shnum in self._shares1.get(peerid, {}):
1355                     if shnum < len(places):
1356                         which = places[shnum]
1357                     else:
1358                         which = "x"
1359                     in_1 = shnum in self._shares1[peerid]
1360                     in_2 = shnum in self._shares2.get(peerid, {})
1361                     in_3 = shnum in self._shares3.get(peerid, {})
1362                     #print peerid_s, shnum, which, in_1, in_2, in_3
1363                     if which == 1:
1364                         if in_1:
1365                             peers[shnum] = self._shares1[peerid][shnum]
1366                             sharemap[shnum] = peerid
1367                     elif which == 2:
1368                         if in_2:
1369                             peers[shnum] = self._shares2[peerid][shnum]
1370                             sharemap[shnum] = peerid
1371                     elif which == 3:
1372                         if in_3:
1373                             peers[shnum] = self._shares3[peerid][shnum]
1374                             sharemap[shnum] = peerid
1375
1376             # we don't bother placing any other shares
1377             # now sort the sequence so that share 0 is returned first
1378             new_sequence = [sharemap[shnum]
1379                             for shnum in sorted(sharemap.keys())]
1380             self._client._storage._sequence = new_sequence
1381             log.msg("merge done")
1382         d.addCallback(_merge)
1383         d.addCallback(lambda res: fn3.download_best_version())
1384         def _retrieved(new_contents):
1385             # the current specified behavior is "first version recoverable"
1386             self.failUnlessEqual(new_contents, contents1)
1387         d.addCallback(_retrieved)
1388         return d
1389
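# A worked check of the fetch schedule described inside _merge() above: walk
# the responses in shnum order and note when each version reaches its k.
# This is a standalone sketch; the (version, k) placement simply mirrors the
# 'places' list rather than anything computed by the test itself.
def _recoverability_schedule():
    placement = [("s2", 4), ("s2", 4), ("s3", 4), ("s2", 4),
                 ("s1", 3), ("s1", 3), ("s1", 3), ("s2", 4)]
    seen = {}
    schedule = {}
    for fetch, (version, k) in enumerate(placement):
        seen[version] = seen.get(version, 0) + 1
        if seen[version] == k and version not in schedule:
            schedule[version] = "abcdefgh"[fetch]
    return schedule # {'s1': 'g', 's2': 'h'}; s3 never reaches its k=4
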
1390 class MultipleVersions(unittest.TestCase, CheckerMixin):
1391     def setUp(self):
1392         self.CONTENTS = ["Contents 0",
1393                          "Contents 1",
1394                          "Contents 2",
1395                          "Contents 3a",
1396                          "Contents 3b"]
1397         self._copied_shares = {}
1398         num_peers = 20
1399         self._client = FakeClient(num_peers)
1400         self._storage = self._client._storage
1401         d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
1402         def _created(node):
1403             self._fn = node
1404             # now create multiple versions of the same file, and accumulate
1405             # their shares, so we can mix and match them later.
1406             d = defer.succeed(None)
1407             d.addCallback(self._copy_shares, 0)
1408             d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
1409             d.addCallback(self._copy_shares, 1)
1410             d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
1411             d.addCallback(self._copy_shares, 2)
1412             d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
1413             d.addCallback(self._copy_shares, 3)
1414             # now we replace all the shares with version s3, and upload a new
1415             # version to get s4b.
1416             rollback = dict([(i,2) for i in range(10)])
1417             d.addCallback(lambda res: self._set_versions(rollback))
1418             d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
1419             d.addCallback(self._copy_shares, 4)
1420             # we leave the storage in state 4
1421             return d
1422         d.addCallback(_created)
1423         return d
1424
1425     def _copy_shares(self, ignored, index):
1426         shares = self._client._storage._peers
1427         # we need a copy deep enough that later writes don't mutate it
1428         new_shares = {}
1429         for peerid in shares:
1430             new_shares[peerid] = {}
1431             for shnum in shares[peerid]:
1432                 new_shares[peerid][shnum] = shares[peerid][shnum]
1433         self._copied_shares[index] = new_shares
1434
1435     def _set_versions(self, versionmap):
1436         # versionmap maps shnums to which version (0,1,2,3,4) we want the
1437         # share to be at. Any shnum which is left out of the map will stay at
1438         # its current version.
1439         shares = self._client._storage._peers
1440         oldshares = self._copied_shares
1441         for peerid in shares:
1442             for shnum in shares[peerid]:
1443                 if shnum in versionmap:
1444                     index = versionmap[shnum]
1445                     shares[peerid][shnum] = oldshares[index][peerid][shnum]
1446
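    # a sketch of typical _set_versions usage, mirroring the calls below:
    #   self._set_versions(dict([(i, 2) for i in (0, 2, 4, 6, 8)]))
    # rolls the five even-numbered shares back to snapshot 2 and leaves the
    # odd-numbered shares at whatever version they currently hold.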
1447     def test_multiple_versions(self):
1448         # if we see a mix of versions in the grid, download_best_version
1449         # should get the latest one
1450         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1451         d = self._fn.download_best_version()
1452         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1453         # and the checker should report problems
1454         d.addCallback(lambda res: self._fn.check())
1455         d.addCallback(self.check_bad, "test_multiple_versions")
1456
1457         # but if everything is at version 2, that's what we should download
1458         d.addCallback(lambda res:
1459                       self._set_versions(dict([(i,2) for i in range(10)])))
1460         d.addCallback(lambda res: self._fn.download_best_version())
1461         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1462         # if exactly one share is at version 3 (i.e. seqnum 4), we should still get v2
1463         d.addCallback(lambda res:
1464                       self._set_versions({0:3}))
1465         d.addCallback(lambda res: self._fn.download_best_version())
1466         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1467         # but the servermap should see the unrecoverable version. This
1468         # depends upon the single newer share being queried early.
1469         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1470         def _check_smap(smap):
1471             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1472             newer = smap.unrecoverable_newer_versions()
1473             self.failUnlessEqual(len(newer), 1)
1474             verinfo, health = newer.items()[0]
1475             self.failUnlessEqual(verinfo[0], 4)
1476             self.failUnlessEqual(health, (1,3))
1477             self.failIf(smap.needs_merge())
1478         d.addCallback(_check_smap)
1479         # if we have a mix of two parallel versions (s4a and s4b), we could
1480         # recover either
1481         d.addCallback(lambda res:
1482                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1483                                           1:4,3:4,5:4,7:4,9:4}))
1484         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1485         def _check_smap_mixed(smap):
1486             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1487             newer = smap.unrecoverable_newer_versions()
1488             self.failUnlessEqual(len(newer), 0)
1489             self.failUnless(smap.needs_merge())
1490         d.addCallback(_check_smap_mixed)
1491         d.addCallback(lambda res: self._fn.download_best_version())
1492         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1493                                                   res == self.CONTENTS[4]))
1494         return d
1495
1496     def test_replace(self):
1497         # if we see a mix of versions in the grid, we should be able to
1498         # replace them all with a newer version
1499
1500         # if exactly one share is at version 3, we should download (and
1501         # modify) v2; the result is published with seqnum 5, above the
1502         # outlier. The index we give to _set_versions differs from the seqnum.
1503         target = dict([(i,2) for i in range(10)]) # seqnum3
1504         target[0] = 3 # seqnum4
1505         self._set_versions(target)
1506
1507         def _modify(oldversion):
1508             return oldversion + " modified"
1509         d = self._fn.modify(_modify)
1510         d.addCallback(lambda res: self._fn.download_best_version())
1511         expected = self.CONTENTS[2] + " modified"
1512         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1513         # and the servermap should indicate that the outlier was replaced too
1514         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1515         def _check_smap(smap):
1516             self.failUnlessEqual(smap.highest_seqnum(), 5)
1517             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1518             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1519         d.addCallback(_check_smap)
1520         return d
1521
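# As test_replace exercises, node.modify(f) re-reads the best recoverable
# version, applies f, and publishes the result with a seqnum above everything
# the servermap saw; that is why the outlier at seqnum 4 forces the new
# write up to seqnum 5.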
1522
1523 class Utils(unittest.TestCase):
1524     def test_dict_of_sets(self):
1525         ds = DictOfSets()
1526         ds.add(1, "a")
1527         ds.add(2, "b")
1528         ds.add(2, "b")
1529         ds.add(2, "c")
1530         self.failUnlessEqual(ds[1], set(["a"]))
1531         self.failUnlessEqual(ds[2], set(["b", "c"]))
1532         ds.discard(3, "d") # should not raise an exception
1533         ds.discard(2, "b")
1534         self.failUnlessEqual(ds[2], set(["c"]))
1535         ds.discard(2, "c")
1536         self.failIf(2 in ds)
1537
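    # DictOfSets behaves roughly like a dict of key->set in which discarding
    # the last element also drops the key (as the failIf above verifies); a
    # sketch of that contract built on plain dicts, illustrative only:
    def _dict_of_sets_sketch(self):
        ds = {}
        ds.setdefault(2, set()).add("c")
        ds[2].discard("c")
        if not ds[2]:
            del ds[2] # mirrors how ds.discard(2, "c") empties key 2 above
        return 2 not in ds
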
1538     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1539         # we compare this against sets of integers
1540         x = set(range(x_start, x_start+x_length))
1541         y = set(range(y_start, y_start+y_length))
1542         should_be_inside = x.issubset(y)
1543         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1544                                                          y_start, y_length),
1545                              str((x_start, x_length, y_start, y_length)))
1546
1547     def test_cache_inside(self):
1548         c = ResponseCache()
1549         x_start = 10
1550         x_length = 5
1551         for y_start in range(8, 17):
1552             for y_length in range(8):
1553                 self._do_inside(c, x_start, x_length, y_start, y_length)
1554
1555     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1556         # we compare this against sets of integers
1557         x = set(range(x_start, x_start+x_length))
1558         y = set(range(y_start, y_start+y_length))
1559         overlap = bool(x.intersection(y))
1560         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1561                                                       y_start, y_length),
1562                              str((x_start, x_length, y_start, y_length)))
1563
1564     def test_cache_overlap(self):
1565         c = ResponseCache()
1566         x_start = 10
1567         x_length = 5
1568         for y_start in range(8, 17):
1569             for y_length in range(8):
1570                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1571
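    # the interval arithmetic being exercised, for non-empty ranges:
    #   _inside:       y_start <= x_start and
    #                  x_start + x_length <= y_start + y_length
    #   _does_overlap: x_start < y_start + y_length and
    #                  y_start < x_start + x_length
    # the brute-force set comparisons above serve as the oracle for both.
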
1572     def test_cache(self):
1573         c = ResponseCache()
1574         # xdata = base62.b2a(os.urandom(100))[:100]
1575         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1576         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1577         nope = (None, None)
1578         c.add("v1", 1, 0, xdata, "time0")
1579         c.add("v1", 1, 2000, ydata, "time1")
1580         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1581         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1582         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1583         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1584         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1585         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1586         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1587         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1588         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1589         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1590         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1591         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1592         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1593         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1594         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1595         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1596         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1597         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1598
1599         # optional: join adjacent fragments (not implemented; check below disabled)
1600         c = ResponseCache()
1601         c.add("v1", 1, 0, xdata[:10], "time0")
1602         c.add("v1", 1, 10, xdata[10:20], "time1")
1603         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1604
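# ResponseCache.read() only returns data when the request lies entirely
# within a single cached add(); adjacent fragments are not joined (the
# disabled assertion above would only pass if they were). A sketch of that
# containment rule:
def _cache_hit_possible(req_start, req_len, cached_start, cached_len):
    # a hit requires the request to be a sub-range of one cached span
    return (cached_start <= req_start and
            req_start + req_len <= cached_start + cached_len)
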
1605 class Exceptions(unittest.TestCase):
1606     def test_repr(self):
1607         nmde = NeedMoreDataError(100, 50, 100)
1608         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1609         ucwe = UncoordinatedWriteError()
1610         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1611
1612 # we can't do the following tests with a FakeClient, since it uses
1613 # FakeStorageServer instances which always succeed. So we need a less-fake one.
1614
1615 class IntentionalError(Exception):
1616     pass
1617
1618 class LocalWrapper:
1619     def __init__(self, original):
1620         self.original = original
1621         self.broken = False
1622         self.post_call_notifier = None
1623     def callRemote(self, methname, *args, **kwargs):
1624         def _call():
1625             if self.broken:
1626                 raise IntentionalError("I was asked to break")
1627             meth = getattr(self.original, "remote_" + methname)
1628             return meth(*args, **kwargs)
1629         d = fireEventually()
1630         d.addCallback(lambda res: _call())
1631         if self.post_call_notifier:
1632             d.addCallback(self.post_call_notifier, methname)
1633         return d
1634
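# LocalWrapper turns wrapper.callRemote("read", ...) into an eventual local
# call of original.remote_read(...). A sketch of the two hooks it exposes
# (the lambda is hypothetical; the Problems tests below only use 'broken'):
#   lw = LocalWrapper(storage_server)
#   lw.broken = True   # every later callRemote errbacks with IntentionalError
#   lw.post_call_notifier = lambda res, methname: log.msg(methname)
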
1635 class LessFakeClient(FakeClient):
1636
1637     def __init__(self, basedir, num_peers=10):
1638         self._num_peers = num_peers
1639         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1640                          for i in range(self._num_peers)]
1641         self._connections = {}
1642         for peerid in self._peerids:
1643             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1644             make_dirs(peerdir)
1645             ss = storage.StorageServer(peerdir)
1646             ss.setNodeID(peerid)
1647             lw = LocalWrapper(ss)
1648             self._connections[peerid] = lw
1649         self.nodeid = "fakenodeid"
1650
1651
1652 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1653     def test_publish_surprise(self):
1654         basedir = os.path.join("mutable", "CollidingWrites", "test_surprise")
1655         self.client = LessFakeClient(basedir)
1656         d = self.client.create_mutable_file("contents 1")
1657         def _created(n):
1658             d = defer.succeed(None)
1659             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1660             def _got_smap1(smap):
1661                 # stash the old state of the file
1662                 self.old_map = smap
1663             d.addCallback(_got_smap1)
1664             # then modify the file, leaving the old map untouched
1665             d.addCallback(lambda res: log.msg("starting winning write"))
1666             d.addCallback(lambda res: n.overwrite("contents 2"))
1667             # now attempt to modify the file with the old servermap. This
1668             # will look just like an uncoordinated write, in which every
1669             # single share got updated between our mapupdate and our publish
1670             d.addCallback(lambda res: log.msg("starting doomed write"))
1671             d.addCallback(lambda res:
1672                           self.shouldFail(UncoordinatedWriteError,
1673                                           "test_publish_surprise", None,
1674                                           n.upload,
1675                                           "contents 2a", self.old_map))
1676             return d
1677         d.addCallback(_created)
1678         return d
1679
1680     def test_retrieve_surprise(self):
1681         basedir = os.path.join("mutable", "CollidingWrites", "test_retrieve")
1682         self.client = LessFakeClient(basedir)
1683         d = self.client.create_mutable_file("contents 1")
1684         def _created(n):
1685             d = defer.succeed(None)
1686             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1687             def _got_smap1(smap):
1688                 # stash the old state of the file
1689                 self.old_map = smap
1690             d.addCallback(_got_smap1)
1691             # then modify the file, leaving the old map untouched
1692             d.addCallback(lambda res: log.msg("starting winning write"))
1693             d.addCallback(lambda res: n.overwrite("contents 2"))
1694             # now attempt to retrieve the old version with the old servermap.
1695             # This will look like someone has changed the file since we
1696             # updated the servermap.
1697             d.addCallback(lambda res: n._cache._clear())
1698             d.addCallback(lambda res: log.msg("starting doomed read"))
1699             d.addCallback(lambda res:
1700                           self.shouldFail(NotEnoughSharesError,
1701                                           "test_retrieve_surprise",
1702                                           "ran out of peers: have 0 shares (k=3)",
1703                                           n.download_version,
1704                                           self.old_map,
1705                                           self.old_map.best_recoverable_version(),
1706                                           ))
1707             return d
1708         d.addCallback(_created)
1709         return d
1710
1711     def test_unexpected_shares(self):
1712         # upload the file, take a servermap, shut down one of the servers,
1713         # upload it again (causing shares to appear on a new server), then
1714         # upload using the old servermap. The last upload should fail with an
1715         # UncoordinatedWriteError, because of the shares that didn't appear
1716         # in the servermap.
1717         basedir = os.path.join("mutable", "CollidingWrites", "test_unexpected_shares")
1718         self.client = LessFakeClient(basedir)
1719         d = self.client.create_mutable_file("contents 1")
1720         def _created(n):
1721             d = defer.succeed(None)
1722             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1723             def _got_smap1(smap):
1724                 # stash the old state of the file
1725                 self.old_map = smap
1726                 # now shut down one of the servers
1727                 peer0 = list(smap.make_sharemap()[0])[0]
1728                 self.client._connections.pop(peer0)
1729                 # then modify the file, leaving the old map untouched
1730                 log.msg("starting winning write")
1731                 return n.overwrite("contents 2")
1732             d.addCallback(_got_smap1)
1733             # now attempt to modify the file with the old servermap. This
1734             # will look just like an uncoordinated write, in which every
1735             # single share got updated between our mapupdate and our publish
1736             d.addCallback(lambda res: log.msg("starting doomed write"))
1737             d.addCallback(lambda res:
1738                           self.shouldFail(UncoordinatedWriteError,
1739                                           "test_unexpected_shares", None,
1740                                           n.upload,
1741                                           "contents 2a", self.old_map))
1742             return d
1743         d.addCallback(_created)
1744         return d
1745
1746     def test_bad_server(self):
1747         # Break one server, then create the file: the initial publish should
1748         # complete with an alternate server. Breaking a second server should
1749         # not prevent an update from succeeding either.
1750         basedir = os.path.join("mutable", "CollidingWrites", "test_bad_server")
1751         self.client = LessFakeClient(basedir, 20)
1752         # to make sure that one of the initial peers is broken, we have to
1753         # get creative. We create the keys, so we can figure out the storage
1754         # index, but we hold off on doing the initial publish until we've
1755         # broken the server on which the first share wants to be stored.
1756         n = FastMutableFileNode(self.client)
1757         d = defer.succeed(None)
1758         d.addCallback(n._generate_pubprivkeys)
1759         d.addCallback(n._generated)
1760         def _break_peer0(res):
1761             si = n.get_storage_index()
1762             peerlist = self.client.get_permuted_peers("storage", si)
1763             peerid0, connection0 = peerlist[0]
1764             peerid1, connection1 = peerlist[1]
1765             connection0.broken = True
1766             self.connection1 = connection1
1767         d.addCallback(_break_peer0)
1768         # now let the initial publish finally happen
1769         d.addCallback(lambda res: n._upload("contents 1", None))
1770         # that ought to work
1771         d.addCallback(lambda res: n.download_best_version())
1772         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1773         # now break the second peer
1774         def _break_peer1(res):
1775             self.connection1.broken = True
1776         d.addCallback(_break_peer1)
1777         d.addCallback(lambda res: n.overwrite("contents 2"))
1778         # that ought to work too
1779         d.addCallback(lambda res: n.download_best_version())
1780         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1781         return d
1782
1783     def test_publish_all_servers_bad(self):
1784         # Break all servers: the publish should fail
1785         basedir = os.path.join("mutable", "CollidingWrites", "publish_all_servers_bad")
1786         self.client = LessFakeClient(basedir, 20)
1787         for connection in self.client._connections.values():
1788             connection.broken = True
1789         d = self.shouldFail(NotEnoughServersError,
1790                             "test_publish_all_servers_bad",
1791                             "Ran out of non-bad servers",
1792                             self.client.create_mutable_file, "contents")
1793         return d
1794
1795     def test_publish_no_servers(self):
1796         # no servers at all: the publish should fail
1797         basedir = os.path.join("mutable", "CollidingWrites", "publish_no_servers")
1798         self.client = LessFakeClient(basedir, 0)
1799         d = self.shouldFail(NotEnoughServersError,
1800                             "test_publish_no_servers",
1801                             "Ran out of non-bad servers",
1802                             self.client.create_mutable_file, "contents")
1803         return d
1804     test_publish_no_servers.timeout = 30
1805
1806
1807     def test_privkey_query_error(self):
1808         # when a servermap is updated with MODE_WRITE, it tries to get the
1809         # privkey. Something might go wrong during this query attempt.
1810         self.client = FakeClient(20)
1811         # we need some contents that are large enough to push the privkey out
1812         # of the early part of the file
1813         LARGE = "These are Larger contents" * 200 # about 5KB
1814         d = self.client.create_mutable_file(LARGE)
1815         def _created(n):
1816             self.uri = n.get_uri()
1817             self.n2 = self.client.create_node_from_uri(self.uri)
1818             # we start by doing a map update to figure out which is the first
1819             # server.
1820             return n.get_servermap(MODE_WRITE)
1821         d.addCallback(_created)
1822         d.addCallback(lambda res: fireEventually(res))
1823         def _got_smap1(smap):
1824             peer0 = list(smap.make_sharemap()[0])[0]
1825             # we tell the server to respond to this peer first, so that it
1826             # will be asked for the privkey first
1827             self.client._storage._sequence = [peer0]
1828             # now we make the peer fail their second query
1829             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1830         d.addCallback(_got_smap1)
1831         # now we update a servermap from a new node (which doesn't have the
1832         # privkey yet, forcing it to use a separate privkey query). Each
1833         # query response will trigger a privkey query, and since we're using
1834         # _sequence to make the peer0 response come back first, we'll send it
1835         # a privkey query first, and _sequence will again ensure that the
1836         # peer0 query will also come back before the others, and then
1837         # _special_answers will make sure that the query raises an exception.
1838         # The whole point of these hijinks is to exercise the code in
1839         # _privkey_query_failed. Note that the map-update will succeed, since
1840         # we'll just get a copy from one of the other shares.
1841         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1842         # Using FakeStorage._sequence means there will be read requests still
1843         # floating around... wait for them to retire
1844         def _cancel_timer(res):
1845             if self.client._storage._pending_timer:
1846                 self.client._storage._pending_timer.cancel()
1847             return res
1848         d.addBoth(_cancel_timer)
1849         return d
1850
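    # the two FakeStorage knobs these privkey tests lean on, in sketch form:
    #   _storage._sequence = [peer0]   # answer peer0's reads before others
    #   _storage._special_answers[peer0] = ["normal", "fail"]
    #                                  # first read is normal, second fails;
    #                                  # "none" returns no shares instead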
1851     def test_privkey_query_missing(self):
1852         # like test_privkey_query_error, but the shares are deleted by the
1853         # second query, instead of raising an exception.
1854         self.client = FakeClient(20)
1855         LARGE = "These are Larger contents" * 200 # about 5KB
1856         d = self.client.create_mutable_file(LARGE)
1857         def _created(n):
1858             self.uri = n.get_uri()
1859             self.n2 = self.client.create_node_from_uri(self.uri)
1860             return n.get_servermap(MODE_WRITE)
1861         d.addCallback(_created)
1862         d.addCallback(lambda res: fireEventually(res))
1863         def _got_smap1(smap):
1864             peer0 = list(smap.make_sharemap()[0])[0]
1865             self.client._storage._sequence = [peer0]
1866             self.client._storage._special_answers[peer0] = ["normal", "none"]
1867         d.addCallback(_got_smap1)
1868         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1869         def _cancel_timer(res):
1870             if self.client._storage._pending_timer:
1871                 self.client._storage._pending_timer.cancel()
1872             return res
1873         d.addBoth(_cancel_timer)
1874         return d