
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.immutable.encode import NotEnoughSharesError
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, IRepairResults
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


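# Illustrative usage sketch (not exercised directly here): a test can steer
# FakeStorage before kicking off a retrieve. The names below are made up.
#
#   s = FakeStorage()
#   s._sequence = list_of_peerids       # answers released in this order,
#                                       # one second after the first query
#   s._special_answers[peerid] = ["fail", "normal"]
#                                       # first read from this peer errors
#                                       # out, later reads behave normally
#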
class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


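# Note that callRemote() above mimics foolscap's eventual-send semantics:
# the named method runs in a later reactor turn, never synchronously, so
# tests see the same reentrancy behavior as real remote calls. A sketch
# (hypothetical names):
#
#   d = server.callRemote("slot_readv", si, [], [(0, 2000)])
#   d.addCallback(lambda shares: ...)   # fires on a later turn, never now
#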
# our "FakeClient" implements just enough of the real Client's functionality
# to let the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1], r[2]) for r in results ]
        return results

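    # get_permuted_peers() mirrors Tahoe's permuted peer ring: each peer is
    # placed at sha1(key + peerid), so the ordering is stable for a given
    # storage index but differs between indexes. For instance, two calls
    # with the same key always return the same list.
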
    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

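# For example, flip_bit("AB", 0) == "@B", since chr(ord("A") ^ 0x01) == "@".
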
def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

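# The 'offset' argument takes three forms, matching the branches above
# (illustrative calls, assuming 's' is a FakeStorage):
#
#   corrupt(None, s, 0)                        # absolute byte offset
#   corrupt(None, s, "signature")              # named field, looked up in
#                                              #  the share's offset table
#   corrupt(None, s, ("share_hash_chain", 4))  # named field plus 4 bytes
#
# "pubkey" is special-cased to byte 107 (just past the fixed-size header),
# since it has no entry in the offset table.
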
class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,))
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        def _toobig_modifier(old_contents):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # log2(10), rounded up
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

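    # For reference: the signed prefix verified in test_generate, packed as
    # ">BQ32s16s BBQQ", holds (version byte, seqnum, root hash, IV/salt,
    # k, N, segment size, data length) -- the same fields unpack_header()
    # yields ahead of the offset table.
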
    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them; when we publish to 3 peers, we should get either 3 or 4
    # shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


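# Example: after publish_multiple(), a call like
#   self._set_versions({0: 2, 3: 4})
# pins share 0 at version index 2 ("Contents 2") and share 3 at version
# index 4 ("Contents 3b"), leaving every other share at its current version.
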
class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


1174 class CheckerMixin:
1175     def check_good(self, r, where):
1176         self.failUnless(r.healthy, where)
1177         self.failIf(r.problems, where)
1178         return r
1179
1180     def check_bad(self, r, where):
1181         self.failIf(r.healthy, where)
1182         return r
1183
1184     def check_expected_failure(self, r, expected_exception, substring, where):
1185         for (peerid, storage_index, shnum, f) in r.problems:
1186             if f.check(expected_exception):
1187                 self.failUnless(substring in str(f),
1188                                 "%s: substring '%s' not in '%s'" %
1189                                 (where, substring, str(f)))
1190                 return
1191         self.fail("%s: didn't see expected exception %s in problems %s" %
1192                   (where, expected_exception, r.problems))
1193
1194
1195 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1196     def setUp(self):
1197         return self.publish_one()
1198
1199
1200     def test_check_good(self):
1201         d = self._fn.check()
1202         d.addCallback(self.check_good, "test_check_good")
1203         return d
1204
1205     def test_check_no_shares(self):
1206         for shares in self._storage._peers.values():
1207             shares.clear()
1208         d = self._fn.check()
1209         d.addCallback(self.check_bad, "test_check_no_shares")
1210         return d
1211
1212     def test_check_not_enough_shares(self):
1213         for shares in self._storage._peers.values():
1214             for shnum in shares.keys():
1215                 if shnum > 0:
1216                     del shares[shnum]
1217         d = self._fn.check()
1218         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1219         return d
1220
1221     def test_check_all_bad_sig(self):
1222         corrupt(None, self._storage, 1) # bad sig
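             # (offset 1 should land in the signed header prefix -- the seqnum
             # field -- so the share's signature no longer verifies)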
1223         d = self._fn.check()
1224         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1225         return d
1226
1227     def test_check_all_bad_blocks(self):
1228         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1229         # the Checker won't notice this: it doesn't look at the actual data
1230         d = self._fn.check()
1231         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1232         return d
1233
1234     def test_verify_good(self):
1235         d = self._fn.check(verify=True)
1236         d.addCallback(self.check_good, "test_verify_good")
1237         return d
1238
1239     def test_verify_all_bad_sig(self):
1240         corrupt(None, self._storage, 1) # bad sig
1241         d = self._fn.check(verify=True)
1242         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1243         return d
1244
1245     def test_verify_one_bad_sig(self):
1246         corrupt(None, self._storage, 1, [9]) # bad sig
1247         d = self._fn.check(verify=True)
1248         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1249         return d
1250
1251     def test_verify_one_bad_block(self):
1252         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1253         # the Verifier *will* notice this, since it examines every byte
1254         d = self._fn.check(verify=True)
1255         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1256         d.addCallback(self.check_expected_failure,
1257                       CorruptShareError, "block hash tree failure",
1258                       "test_verify_one_bad_block")
1259         return d
1260
1261     def test_verify_one_bad_sharehash(self):
1262         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1263         d = self._fn.check(verify=True)
1264         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1265         d.addCallback(self.check_expected_failure,
1266                       CorruptShareError, "corrupt hashes",
1267                       "test_verify_one_bad_sharehash")
1268         return d
1269
1270     def test_verify_one_bad_encprivkey(self):
1271         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1272         d = self._fn.check(verify=True)
1273         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1274         d.addCallback(self.check_expected_failure,
1275                       CorruptShareError, "invalid privkey",
1276                       "test_verify_one_bad_encprivkey")
1277         return d
1278
1279     def test_verify_one_bad_encprivkey_uncheckable(self):
1280         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1281         readonly_fn = self._fn.get_readonly()
1282         # a read-only node has no way to validate the privkey
1283         d = readonly_fn.check(verify=True)
1284         d.addCallback(self.check_good,
1285                       "test_verify_one_bad_encprivkey_uncheckable")
1286         return d
1287
1288 class Repair(unittest.TestCase, PublishMixin):
1289
1290     def get_shares(self, s):
1291         all_shares = {} # maps (peerid, shnum) to share data
1292         for peerid in s._peers:
1293             shares = s._peers[peerid]
1294             for shnum in shares:
1295                 data = shares[shnum]
1296                 all_shares[ (peerid, shnum) ] = data
1297         return all_shares
1298
1299     def test_repair_nop(self):
1300         d = self.publish_one()
1301         def _published(res):
1302             self.initial_shares = self.get_shares(self._storage)
1303         d.addCallback(_published)
1304         d.addCallback(lambda res: self._fn.check())
1305         d.addCallback(lambda check_results: self._fn.repair(check_results))
1306         def _check_results(rres):
1307             self.failUnless(IRepairResults.providedBy(rres))
1308             # TODO: examine results
1309
1310             new_shares = self.get_shares(self._storage)
1311             # all shares should be in the same place as before
1312             self.failUnlessEqual(set(self.initial_shares.keys()),
1313                                  set(new_shares.keys()))
1314             # but they should all be at a newer seqnum. The IV will be
1315             # different, so the roothash will be too.
1316             for key in self.initial_shares:
1317                 (version0,
1318                  seqnum0,
1319                  root_hash0,
1320                  IV0,
1321                  k0, N0, segsize0, datalen0,
1322                  o0) = unpack_header(self.initial_shares[key])
1323                 (version1,
1324                  seqnum1,
1325                  root_hash1,
1326                  IV1,
1327                  k1, N1, segsize1, datalen1,
1328                  o1) = unpack_header(new_shares[key])
1329                 self.failUnlessEqual(version0, version1)
1330                 self.failUnlessEqual(seqnum0+1, seqnum1)
1331                 self.failUnlessEqual(k0, k1)
1332                 self.failUnlessEqual(N0, N1)
1333                 self.failUnlessEqual(segsize0, segsize1)
1334                 self.failUnlessEqual(datalen0, datalen1)
1335         d.addCallback(_check_results)
1336         return d
1337
1338
1339 class MultipleEncodings(unittest.TestCase):
1340     def setUp(self):
1341         self.CONTENTS = "New contents go here"
1342         num_peers = 20
1343         self._client = FakeClient(num_peers)
1344         self._storage = self._client._storage
1345         d = self._client.create_mutable_file(self.CONTENTS)
1346         def _created(node):
1347             self._fn = node
1348         d.addCallback(_created)
1349         return d
1350
1351     def _encode(self, k, n, data):
1352         # encode 'data' into a peerid->shares dict.
1353
1354         fn2 = FastMutableFileNode(self._client)
1355         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1356         # and _fingerprint
1357         fn = self._fn
1358         fn2.init_from_uri(fn.get_uri())
1359         # then we copy over other fields that are normally fetched from the
1360         # existing shares
1361         fn2._pubkey = fn._pubkey
1362         fn2._privkey = fn._privkey
1363         fn2._encprivkey = fn._encprivkey
1364         # and set the encoding parameters to something completely different
1365         fn2._required_shares = k
1366         fn2._total_shares = n
1367
1368         s = self._client._storage
1369         s._peers = {} # clear existing storage
1370         p2 = Publish(fn2, None)
1371         d = p2.publish(data)
1372         def _published(res):
1373             shares = s._peers
1374             s._peers = {}
1375             return shares
1376         d.addCallback(_published)
1377         return d
1378
1379     def make_servermap(self, mode=MODE_READ, oldmap=None):
1380         if oldmap is None:
1381             oldmap = ServerMap()
1382         smu = ServermapUpdater(self._fn, oldmap, mode)
1383         d = smu.update()
1384         return d
1385
1386     def test_multiple_encodings(self):
1387         # we encode the same file three different ways (3-of-10, 4-of-9,
1388         # and 4-of-7), then mix up the shares, to make sure that download
1389         # survives seeing a variety of encodings. This is tricky to set up.
1390
1391         contents1 = "Contents for encoding 1 (3-of-10) go here"
1392         contents2 = "Contents for encoding 2 (4-of-9) go here"
1393         contents3 = "Contents for encoding 3 (4-of-7) go here"
1394
1395         # we make a retrieval object that doesn't know what encoding
1396         # parameters to use
1397         fn3 = FastMutableFileNode(self._client)
1398         fn3.init_from_uri(self._fn.get_uri())
1399
1400         # now we publish the file once per encoding, and grab the shares
1401         d = self._encode(3, 10, contents1)
1402         def _encoded_1(shares):
1403             self._shares1 = shares
1404         d.addCallback(_encoded_1)
1405         d.addCallback(lambda res: self._encode(4, 9, contents2))
1406         def _encoded_2(shares):
1407             self._shares2 = shares
1408         d.addCallback(_encoded_2)
1409         d.addCallback(lambda res: self._encode(4, 7, contents3))
1410         def _encoded_3(shares):
1411             self._shares3 = shares
1412         d.addCallback(_encoded_3)
1413
1414         def _merge(res):
1415             log.msg("merging sharelists")
1416             # we merge the shares from the three sets, leaving each shnum
1417             # in its original location, but using a share from set1, set2,
1418             # or set3 according to the following sequence:
1419             #
1420             #  4-of-9  a  s2
1421             #  4-of-9  b  s2
1422             #  4-of-7  c   s3
1423             #  4-of-9  d  s2
1424             #  3-of-10 e s1
1425             #  3-of-10 f s1
1426             #  3-of-10 g s1
1427             #  4-of-9  h  s2
1428             #
1429             # so that no version can be recovered until fetch [g], at
1430             # which point version-s1 (the 3-of-10 form) gets its third
1431             # share and becomes recoverable. If the implementation latches
1432             # on to the first version it sees, s2 is recoverable at [h].
1433
1434             # Later, when we implement code that handles multiple versions,
1435             # we can use this framework to assert that all recoverable
1436             # versions are retrieved, and test that 'epsilon' does its job.
1437
1438             places = [2, 2, 3, 2, 1, 1, 1, 2]
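                 # places[] maps shnum -> which encoding supplies that share:
                 # e.g. places[4] == 1 puts the s1 (3-of-10) share at shnum 4,
                 # and shnums past the end of the list (the "x" case below)
                 # get no share at all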
1439
1440             sharemap = {}
1441
1442             for i,peerid in enumerate(self._client._peerids):
1443                 peerid_s = shortnodeid_b2a(peerid)
1444                 self._client._storage._peers[peerid] = peers = {} # one dict per peer, not per shnum
1445                 for shnum in self._shares1.get(peerid, {}):
1446                     if shnum < len(places):
1447                         which = places[shnum]
1448                     else:
1449                         which = "x"
1450                     in_1 = shnum in self._shares1[peerid]
1451                     in_2 = shnum in self._shares2.get(peerid, {})
1452                     in_3 = shnum in self._shares3.get(peerid, {})
1453                     #print peerid_s, shnum, which, in_1, in_2, in_3
1454                     if which == 1:
1455                         if in_1:
1456                             peers[shnum] = self._shares1[peerid][shnum]
1457                             sharemap[shnum] = peerid
1458                     elif which == 2:
1459                         if in_2:
1460                             peers[shnum] = self._shares2[peerid][shnum]
1461                             sharemap[shnum] = peerid
1462                     elif which == 3:
1463                         if in_3:
1464                             peers[shnum] = self._shares3[peerid][shnum]
1465                             sharemap[shnum] = peerid
1466
1467             # we don't bother placing any other shares
1468             # now sort the sequence so that share 0 is returned first
1469             new_sequence = [sharemap[shnum]
1470                             for shnum in sorted(sharemap.keys())]
1471             self._client._storage._sequence = new_sequence
1472             log.msg("merge done")
1473         d.addCallback(_merge)
1474         d.addCallback(lambda res: fn3.download_best_version())
1475         def _retrieved(new_contents):
1476             # the current specified behavior is "first version recoverable"
1477             self.failUnlessEqual(new_contents, contents1)
1478         d.addCallback(_retrieved)
1479         return d
1480
1481
1482 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1483
1484     def setUp(self):
1485         return self.publish_multiple()
1486
1487     def test_multiple_versions(self):
1488         # if we see a mix of versions in the grid, download_best_version
1489         # should get the latest one
1490         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
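             # here shares 0,2,4,6,8 are rolled back to the older version at
             # index 2, while the odd-numbered shares keep the newest one
             # (index 4), so CONTENTS[4] is what download_best_version should
             # return, as asserted below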
1491         d = self._fn.download_best_version()
1492         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1493         # and the checker should report problems
1494         d.addCallback(lambda res: self._fn.check())
1495         d.addCallback(self.check_bad, "test_multiple_versions")
1496
1497         # but if everything is at version 2, that's what we should download
1498         d.addCallback(lambda res:
1499                       self._set_versions(dict([(i,2) for i in range(10)])))
1500         d.addCallback(lambda res: self._fn.download_best_version())
1501         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1502         # if exactly one share is at version 3, we should still get v2
1503         d.addCallback(lambda res:
1504                       self._set_versions({0:3}))
1505         d.addCallback(lambda res: self._fn.download_best_version())
1506         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1507         # but the servermap should see the unrecoverable version. This
1508         # depends upon the single newer share being queried early.
1509         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1510         def _check_smap(smap):
1511             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1512             newer = smap.unrecoverable_newer_versions()
1513             self.failUnlessEqual(len(newer), 1)
1514             verinfo, health = newer.items()[0]
1515             self.failUnlessEqual(verinfo[0], 4)
1516             self.failUnlessEqual(health, (1,3))
1517             self.failIf(smap.needs_merge())
1518         d.addCallback(_check_smap)
1519         # if we have a mix of two parallel versions (s4a and s4b), we could
1520         # recover either
1521         d.addCallback(lambda res:
1522                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1523                                           1:4,3:4,5:4,7:4,9:4}))
1524         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1525         def _check_smap_mixed(smap):
1526             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1527             newer = smap.unrecoverable_newer_versions()
1528             self.failUnlessEqual(len(newer), 0)
1529             self.failUnless(smap.needs_merge())
1530         d.addCallback(_check_smap_mixed)
1531         d.addCallback(lambda res: self._fn.download_best_version())
1532         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1533                                                   res == self.CONTENTS[4]))
1534         return d
1535
1536     def test_replace(self):
1537         # if we see a mix of versions in the grid, we should be able to
1538         # replace them all with a newer version
1539
1540         # if exactly one share is at version 3, we should download (and
1541         # replace) v2, publishing a still-newer version over both. Note that
1542         # the index we give to _set_versions differs from the sequence number.
1543         target = dict([(i,2) for i in range(10)]) # seqnum3
1544         target[0] = 3 # seqnum4
1545         self._set_versions(target)
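             # recap of the bookkeeping: CONTENTS index 2 was published at
             # seqnum 3 and index 3 at seqnum 4, so the modify() below should
             # publish at seqnum 5, which _check_smap confirms at the end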
1546
1547         def _modify(oldversion):
1548             return oldversion + " modified"
1549         d = self._fn.modify(_modify)
1550         d.addCallback(lambda res: self._fn.download_best_version())
1551         expected = self.CONTENTS[2] + " modified"
1552         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1553         # and the servermap should indicate that the outlier was replaced too
1554         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1555         def _check_smap(smap):
1556             self.failUnlessEqual(smap.highest_seqnum(), 5)
1557             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1558             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1559         d.addCallback(_check_smap)
1560         return d
1561
1562
1563 class Utils(unittest.TestCase):
1564     def test_dict_of_sets(self):
1565         ds = DictOfSets()
1566         ds.add(1, "a")
1567         ds.add(2, "b")
1568         ds.add(2, "b")
1569         ds.add(2, "c")
1570         self.failUnlessEqual(ds[1], set(["a"]))
1571         self.failUnlessEqual(ds[2], set(["b", "c"]))
1572         ds.discard(3, "d") # should not raise an exception
1573         ds.discard(2, "b")
1574         self.failUnlessEqual(ds[2], set(["c"]))
1575         ds.discard(2, "c")
1576         self.failIf(2 in ds)
1577
1578     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1579         # we compare this against sets of integers
1580         x = set(range(x_start, x_start+x_length))
1581         y = set(range(y_start, y_start+y_length))
1582         should_be_inside = x.issubset(y)
1583         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1584                                                          y_start, y_length),
1585                              str((x_start, x_length, y_start, y_length)))
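             # a closed-form sketch of what _inside is expected to compute,
             # for half-open [start, start+length) ranges (an empty x is
             # vacuously inside anything):
             #   x_length == 0 or (y_start <= x_start and
             #                     x_start + x_length <= y_start + y_length)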
1586
1587     def test_cache_inside(self):
1588         c = ResponseCache()
1589         x_start = 10
1590         x_length = 5
1591         for y_start in range(8, 17):
1592             for y_length in range(8):
1593                 self._do_inside(c, x_start, x_length, y_start, y_length)
1594
1595     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1596         # we compare this against sets of integers
1597         x = set(range(x_start, x_start+x_length))
1598         y = set(range(y_start, y_start+y_length))
1599         overlap = bool(x.intersection(y))
1600         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1601                                                       y_start, y_length),
1602                              str((x_start, x_length, y_start, y_length)))
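             # closed-form equivalent: two half-open ranges overlap iff both
             # are non-empty and each one starts before the other ends:
             #   (x_length > 0 and y_length > 0 and
             #    x_start < y_start + y_length and y_start < x_start + x_length)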
1603
1604     def test_cache_overlap(self):
1605         c = ResponseCache()
1606         x_start = 10
1607         x_length = 5
1608         for y_start in range(8, 17):
1609             for y_length in range(8):
1610                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1611
1612     def test_cache(self):
1613         c = ResponseCache()
1614         # xdata = base62.b2a(os.urandom(100))[:100]
1615         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1616         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1617         nope = (None, None)
1618         c.add("v1", 1, 0, xdata, "time0")
1619         c.add("v1", 1, 2000, ydata, "time1")
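             # the cache now holds two disjoint spans for ("v1", shnum 1):
             # bytes [0:100) == xdata and [2000:2100) == ydata. The reads
             # below assume a hit requires the requested range to lie entirely
             # within a single cached span; anything straddling a span edge or
             # touching the gap comes back as `nope`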
1620         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1621         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1622         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1623         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1624         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1625         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1626         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1627         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1628         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1629         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1630         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1631         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1632         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1633         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1634         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1635         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1636         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1637         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1638
1639         # optional: join fragments
1640         c = ResponseCache()
1641         c.add("v1", 1, 0, xdata[:10], "time0")
1642         c.add("v1", 1, 10, xdata[10:20], "time1")
1643         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1644
1645 class Exceptions(unittest.TestCase):
1646     def test_repr(self):
1647         nmde = NeedMoreDataError(100, 50, 100)
1648         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1649         ucwe = UncoordinatedWriteError()
1650         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1651
1652 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1653 # instances which always succeed. So we need a less-fake one.
1654
1655 class IntentionalError(Exception):
1656     pass
1657
1658 class LocalWrapper:
1659     def __init__(self, original):
1660         self.original = original
1661         self.broken = False
1662         self.post_call_notifier = None
1663     def callRemote(self, methname, *args, **kwargs):
1664         def _call():
1665             if self.broken:
1666                 raise IntentionalError("I was asked to break")
1667             meth = getattr(self.original, "remote_" + methname)
1668             return meth(*args, **kwargs)
1669         d = fireEventually()
1670         d.addCallback(lambda res: _call())
1671         if self.post_call_notifier:
1672             d.addCallback(self.post_call_notifier, methname)
1673         return d
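 # usage sketch: given any object with remote_* methods,
 #   d = LocalWrapper(obj).callRemote("foo", 1, 2)
 # fires obj.remote_foo(1, 2) on a later reactor turn, like a real foolscap
 # call ("foo" is just a placeholder name); once .broken is set, calls raise
 # IntentionalError instead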
1674
1675 class LessFakeClient(FakeClient):
1676
1677     def __init__(self, basedir, num_peers=10):
1678         self._num_peers = num_peers
1679         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1680                          for i in range(self._num_peers)]
1681         self._connections = {}
1682         for peerid in self._peerids:
1683             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1684             make_dirs(peerdir)
1685             ss = storage.StorageServer(peerdir)
1686             ss.setNodeID(peerid)
1687             lw = LocalWrapper(ss)
1688             self._connections[peerid] = lw
1689         self.nodeid = "fakenodeid"
1690
1691
1692 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1693     def test_publish_surprise(self):
1694         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1695         self.client = LessFakeClient(basedir)
1696         d = self.client.create_mutable_file("contents 1")
1697         def _created(n):
1698             d = defer.succeed(None)
1699             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1700             def _got_smap1(smap):
1701                 # stash the old state of the file
1702                 self.old_map = smap
1703             d.addCallback(_got_smap1)
1704             # then modify the file, leaving the old map untouched
1705             d.addCallback(lambda res: log.msg("starting winning write"))
1706             d.addCallback(lambda res: n.overwrite("contents 2"))
1707             # now attempt to modify the file with the old servermap. This
1708             # will look just like an uncoordinated write, in which every
1709             # single share got updated between our mapupdate and our publish
1710             d.addCallback(lambda res: log.msg("starting doomed write"))
1711             d.addCallback(lambda res:
1712                           self.shouldFail(UncoordinatedWriteError,
1713                                           "test_publish_surprise", None,
1714                                           n.upload,
1715                                           "contents 2a", self.old_map))
1716             return d
1717         d.addCallback(_created)
1718         return d
1719
1720     def test_retrieve_surprise(self):
1721         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1722         self.client = LessFakeClient(basedir)
1723         d = self.client.create_mutable_file("contents 1")
1724         def _created(n):
1725             d = defer.succeed(None)
1726             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1727             def _got_smap1(smap):
1728                 # stash the old state of the file
1729                 self.old_map = smap
1730             d.addCallback(_got_smap1)
1731             # then modify the file, leaving the old map untouched
1732             d.addCallback(lambda res: log.msg("starting winning write"))
1733             d.addCallback(lambda res: n.overwrite("contents 2"))
1734             # now attempt to retrieve the old version with the old servermap.
1735             # This will look like someone has changed the file since we
1736             # updated the servermap.
1737             d.addCallback(lambda res: n._cache._clear())
1738             d.addCallback(lambda res: log.msg("starting doomed read"))
1739             d.addCallback(lambda res:
1740                           self.shouldFail(NotEnoughSharesError,
1741                                           "test_retrieve_surprise",
1742                                           "ran out of peers: have 0 shares (k=3)",
1743                                           n.download_version,
1744                                           self.old_map,
1745                                           self.old_map.best_recoverable_version(),
1746                                           ))
1747             return d
1748         d.addCallback(_created)
1749         return d
1750
1751     def test_unexpected_shares(self):
1752         # upload the file, take a servermap, shut down one of the servers,
1753         # upload it again (causing shares to appear on a new server), then
1754         # upload using the old servermap. The last upload should fail with an
1755         # UncoordinatedWriteError, because of the shares that didn't appear
1756         # in the servermap.
1757         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1758         self.client = LessFakeClient(basedir)
1759         d = self.client.create_mutable_file("contents 1")
1760         def _created(n):
1761             d = defer.succeed(None)
1762             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1763             def _got_smap1(smap):
1764                 # stash the old state of the file
1765                 self.old_map = smap
1766                 # now shut down one of the servers
1767                 peer0 = list(smap.make_sharemap()[0])[0]
1768                 self.client._connections.pop(peer0)
1769                 # then modify the file, leaving the old map untouched
1770                 log.msg("starting winning write")
1771                 return n.overwrite("contents 2")
1772             d.addCallback(_got_smap1)
1773             # now attempt to modify the file with the old servermap. This
1774             # will look just like an uncoordinated write, in which every
1775             # single share got updated between our mapupdate and our publish
1776             d.addCallback(lambda res: log.msg("starting doomed write"))
1777             d.addCallback(lambda res:
1778                           self.shouldFail(UncoordinatedWriteError,
1779                                           "test_unexpected_shares", None,
1780                                           n.upload,
1781                                           "contents 2a", self.old_map))
1782             return d
1783         d.addCallback(_created)
1784         return d
1785
1786     def test_bad_server(self):
1787         # Break one server, then create the file: the initial publish should
1788         # complete with an alternate server. Breaking a second server should
1789         # not prevent an update from succeeding either.
1790         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1791         self.client = LessFakeClient(basedir, 20)
1792         # to make sure that one of the initial peers is broken, we have to
1793         # get creative. We create the keys, so we can figure out the storage
1794         # index, but we hold off on doing the initial publish until we've
1795         # broken the server on which the first share wants to be stored.
1796         n = FastMutableFileNode(self.client)
1797         d = defer.succeed(None)
1798         d.addCallback(n._generate_pubprivkeys)
1799         d.addCallback(n._generated)
1800         def _break_peer0(res):
1801             si = n.get_storage_index()
1802             peerlist = self.client.get_permuted_peers("storage", si)
1803             peerid0, connection0 = peerlist[0]
1804             peerid1, connection1 = peerlist[1]
1805             connection0.broken = True
1806             self.connection1 = connection1
1807         d.addCallback(_break_peer0)
1808         # now let the initial publish finally happen
1809         d.addCallback(lambda res: n._upload("contents 1", None))
1810         # that ought to work
1811         d.addCallback(lambda res: n.download_best_version())
1812         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1813         # now break the second peer
1814         def _break_peer1(res):
1815             self.connection1.broken = True
1816         d.addCallback(_break_peer1)
1817         d.addCallback(lambda res: n.overwrite("contents 2"))
1818         # that ought to work too
1819         d.addCallback(lambda res: n.download_best_version())
1820         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1821         return d
1822
1823     def test_publish_all_servers_bad(self):
1824         # Break all servers: the publish should fail
1825         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1826         self.client = LessFakeClient(basedir, 20)
1827         for connection in self.client._connections.values():
1828             connection.broken = True
1829         d = self.shouldFail(NotEnoughServersError,
1830                             "test_publish_all_servers_bad",
1831                             "Ran out of non-bad servers",
1832                             self.client.create_mutable_file, "contents")
1833         return d
1834
1835     def test_publish_no_servers(self):
1836         # no servers at all: the publish should fail
1837         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1838         self.client = LessFakeClient(basedir, 0)
1839         d = self.shouldFail(NotEnoughServersError,
1840                             "test_publish_no_servers",
1841                             "Ran out of non-bad servers",
1842                             self.client.create_mutable_file, "contents")
1843         return d
1844     test_publish_no_servers.timeout = 30
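         # (trial honors a per-method .timeout attribute, in seconds, in place
         # of its default two-minute limit)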
1845
1846
1847     def test_privkey_query_error(self):
1848         # when a servermap is updated with MODE_WRITE, it tries to get the
1849         # privkey. Something might go wrong during this query attempt.
1850         self.client = FakeClient(20)
1851         # we need some contents that are large enough to push the privkey out
1852         # of the early part of the file
1853         LARGE = "These are Larger contents" * 200 # about 5KB
1854         d = self.client.create_mutable_file(LARGE)
1855         def _created(n):
1856             self.uri = n.get_uri()
1857             self.n2 = self.client.create_node_from_uri(self.uri)
1858             # we start by doing a map update to figure out which is the first
1859             # server.
1860             return n.get_servermap(MODE_WRITE)
1861         d.addCallback(_created)
1862         d.addCallback(lambda res: fireEventually(res))
1863         def _got_smap1(smap):
1864             peer0 = list(smap.make_sharemap()[0])[0]
1865             # we arrange for this peer's response to come back first, so
1866             # that it will be asked for the privkey first
1867             self.client._storage._sequence = [peer0]
1868             # now we make the peer fail their second query
1869             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1870         d.addCallback(_got_smap1)
1871         # now we update a servermap from a new node (which doesn't have the
1872         # privkey yet, forcing it to use a separate privkey query). Each
1873         # query response will trigger a privkey query, and since we're using
1874         # _sequence to make the peer0 response come back first, we'll send it
1875         # a privkey query first, and _sequence will again ensure that the
1876         # peer0 query will also come back before the others, and then
1877         # _special_answers will make sure that the query raises an exception.
1878         # The whole point of these hijinks is to exercise the code in
1879         # _privkey_query_failed. Note that the map-update will succeed, since
1880         # we'll just get a copy from one of the other shares.
1881         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1882         # Using FakeStorage._sequence means there will be read requests
1883         # still floating around; wait for them to retire
1884         def _cancel_timer(res):
1885             if self.client._storage._pending_timer:
1886                 self.client._storage._pending_timer.cancel()
1887             return res
1888         d.addBoth(_cancel_timer)
1889         return d
1890
1891     def test_privkey_query_missing(self):
1892         # like test_privkey_query_error, but the shares are deleted by the
1893         # second query, instead of raising an exception.
1894         self.client = FakeClient(20)
1895         LARGE = "These are Larger contents" * 200 # about 5KB
1896         d = self.client.create_mutable_file(LARGE)
1897         def _created(n):
1898             self.uri = n.get_uri()
1899             self.n2 = self.client.create_node_from_uri(self.uri)
1900             return n.get_servermap(MODE_WRITE)
1901         d.addCallback(_created)
1902         d.addCallback(lambda res: fireEventually(res))
1903         def _got_smap1(smap):
1904             peer0 = list(smap.make_sharemap()[0])[0]
1905             self.client._storage._sequence = [peer0]
1906             self.client._storage._special_answers[peer0] = ["normal", "none"]
1907         d.addCallback(_got_smap1)
1908         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1909         def _cancel_timer(res):
1910             if self.client._storage._pending_timer:
1911                 self.client._storage._pending_timer.cancel()
1912             return res
1913         d.addBoth(_cancel_timer)
1914         return d