
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repair import MustForceRepairError

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

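# FakeStorage's "fail" mode (below) raises IntentionalError, which is neither
# imported nor defined above. Unless the rest of the module defines it, a
# minimal definition like this one (an assumption, not original code) is
# needed to keep the "fail" mode from raising NameError:
class IntentionalError(Exception):
    pass
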
class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}
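        # _special_answers maps peerid -> a list of modes ("fail", "none",
        # "normal"); each read() from that peer consumes one entry, letting a
        # test script per-peer misbehavior.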

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)
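
    # A usage sketch (an illustration, not part of the original tests): a
    # test can force a specific response order by assigning a peerid list
    # before triggering reads:
    #
    #   s = FakeStorage()
    #   s._sequence = list(s._peers)   # answer queries in this order
    #
    # While _sequence is set, read() parks each query in _pending, and
    # _fire_readers releases them all one second after the first arrives.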

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
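        # fireEventually returns a Deferred that fires on a later reactor
        # turn, so the fake call is delivered with the same asynchrony as a
        # real remote call.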
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget: swallow the result and return nothing
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)
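
    # A driving example (an illustration under assumed arguments, not from
    # the original tests): write "X" at offset 0 of share 0 with an empty
    # (always-passing) test vector:
    #
    #   d = server.slot_testv_and_readv_and_writev("si", None,
    #           {0: ([], [(0, "X")], None)}, [])
    #
    # The Deferred fires with (True, {0: []}).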


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results
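    # Note: this mirrors Tahoe's permuted-ring peer selection. Peers are
    # ranked by sha1(key + peerid), so different keys (storage indexes) see
    # the same servers in different but stable orders; a quick check (sketch):
    #
    #   c = FakeClient(num_peers=4)
    #   order1 = [p for (p, _) in c.get_permuted_peers("storage", "key1")]
    #   order2 = [p for (p, _) in c.get_permuted_peers("storage", "key2")]
    #   # order1 and order2 are (almost always) different permutations of
    #   # the same four peerids.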

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
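# For example, flip_bit("\x00\x07", 1) returns "\x00\x06": the low bit of the
# byte at byte_offset is inverted, leaving everything else untouched.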

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
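
# Examples: corrupt(None, storage, "signature") flips one bit in the
# signature field of every share; corrupt(None, storage, ("share_hash_chain", 4))
# flips a bit four bytes into the share hash chain; a plain integer offset
# (e.g. corrupt(None, storage, 1)) indexes into the fixed-size header.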

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
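        # (best_recoverable_version() returns a verinfo tuple of the form
        # (seqnum, root_hash, IV, segsize, datalen, k, N, prefix,
        # offsets_tuple); element 0 is the sequence number.)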
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        def _toobig_modifier(old_contents):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
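                # (">BQ32s16s BBQQ" packs the signed SDMF prefix: version
                # byte, seqnum, 32-byte root hash, 16-byte IV/salt, then k,
                # N, segment size, and data length.)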
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them; when we publish to 3 peers, we should get either 3 or 4
    # shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError.

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
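
    # For example, _set_versions({0: 1, 1: 1}) reverts shares 0 and 1 on
    # every peer to the copies captured by _copy_shares(..., 1), leaving all
    # other shares at their current version.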


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_version(servermap, ver)
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_version, servermap, ver)
            if failure_checker:
                d1.addCallback(failure_checker)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

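    # In the tests below, corrupt_early=True flips bits before the servermap
    # is built, so mapupdate sees the damage; corrupt_early=False flips them
    # between mapupdate and retrieve, exercising the "file changed underneath
    # us" recovery path.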
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)


    def test_corrupt_all_seqnum_late(self):
        # corrupting the seqnum between mapupdate and retrieve should result
        # in NotEnoughSharesError, since each share will look invalid
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
            self.failUnless("someone wrote to the data since we read the servermap" in str(f))
        return self._test_corrupt_all(1, "ran out of peers",
                                      corrupt_early=False,
                                      failure_checker=_check)

    def test_corrupt_all_block_hash_tree_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_corrupt_all_block_late(self):
        def _check(res):
            f = res[0]
            self.failUnless(f.check(NotEnoughSharesError))
        return self._test_corrupt_all("share_data", "block hash tree failure",
                                      corrupt_early=False,
                                      failure_checker=_check)


    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class CheckerMixin:
    def check_good(self, r, where):
        self.failUnless(r.is_healthy(), where)
        return r

    def check_bad(self, r, where):
        self.failIf(r.is_healthy(), where)
        return r

    def check_expected_failure(self, r, expected_exception, substring, where):
        for (peerid, storage_index, shnum, f) in r.problems:
            if f.check(expected_exception):
                self.failUnless(substring in str(f),
                                "%s: substring '%s' not in '%s'" %
                                (where, substring, str(f)))
                return
        self.fail("%s: didn't see expected exception %s in problems %s" %
                  (where, expected_exception, r.problems))


class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()


    def test_check_good(self):
        d = self._fn.check(Monitor())
        d.addCallback(self.check_good, "test_check_good")
        return d

    def test_check_no_shares(self):
        for shares in self._storage._peers.values():
            shares.clear()
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_no_shares")
        return d

    def test_check_not_enough_shares(self):
        for shares in self._storage._peers.values():
            for shnum in shares.keys():
                if shnum > 0:
                    del shares[shnum]
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_not_enough_shares")
        return d

    def test_check_all_bad_sig(self):
        corrupt(None, self._storage, 1) # bad sig
        d = self._fn.check(Monitor())
        d.addCallback(self.check_bad, "test_check_all_bad_sig")
        return d

    def test_check_all_bad_blocks(self):
        corrupt(None, self._storage, "share_data", [9]) # bad blocks
1241         # the Checker won't notice this, since it doesn't look at actual data
1242         d = self._fn.check(Monitor())
1243         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1244         return d
1245
1246     def test_verify_good(self):
1247         d = self._fn.check(Monitor(), verify=True)
1248         d.addCallback(self.check_good, "test_verify_good")
1249         return d
1250
1251     def test_verify_all_bad_sig(self):
1252         corrupt(None, self._storage, 1) # bad sig
1253         d = self._fn.check(Monitor(), verify=True)
1254         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1255         return d
1256
1257     def test_verify_one_bad_sig(self):
1258         corrupt(None, self._storage, 1, [9]) # bad sig
1259         d = self._fn.check(Monitor(), verify=True)
1260         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1261         return d
1262
1263     def test_verify_one_bad_block(self):
1264         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1265         # the Verifier *will* notice this, since it examines every byte
1266         d = self._fn.check(Monitor(), verify=True)
1267         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1268         d.addCallback(self.check_expected_failure,
1269                       CorruptShareError, "block hash tree failure",
1270                       "test_verify_one_bad_block")
1271         return d
1272
1273     def test_verify_one_bad_sharehash(self):
1274         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1275         d = self._fn.check(Monitor(), verify=True)
1276         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1277         d.addCallback(self.check_expected_failure,
1278                       CorruptShareError, "corrupt hashes",
1279                       "test_verify_one_bad_sharehash")
1280         return d
1281
1282     def test_verify_one_bad_encprivkey(self):
1283         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1284         d = self._fn.check(Monitor(), verify=True)
1285         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1286         d.addCallback(self.check_expected_failure,
1287                       CorruptShareError, "invalid privkey",
1288                       "test_verify_one_bad_encprivkey")
1289         return d
1290
1291     def test_verify_one_bad_encprivkey_uncheckable(self):
1292         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1293         readonly_fn = self._fn.get_readonly()
1294         # a read-only node has no way to validate the privkey
1295         d = readonly_fn.check(Monitor(), verify=True)
1296         d.addCallback(self.check_good,
1297                       "test_verify_one_bad_encprivkey_uncheckable")
1298         return d
1299
1300 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1301
1302     def get_shares(self, s):
1303         all_shares = {} # maps (peerid, shnum) to share data
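             # e.g. { (peerid_A, 0): data, (peerid_A, 3): data, (peerid_B, 1): data }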
1304         for peerid in s._peers:
1305             shares = s._peers[peerid]
1306             for shnum in shares:
1307                 data = shares[shnum]
1308                 all_shares[ (peerid, shnum) ] = data
1309         return all_shares
1310
1311     def copy_shares(self, ignored=None):
1312         self.old_shares.append(self.get_shares(self._storage))
1313
1314     def test_repair_nop(self):
1315         self.old_shares = []
1316         d = self.publish_one()
1317         d.addCallback(self.copy_shares)
1318         d.addCallback(lambda res: self._fn.check(Monitor()))
1319         d.addCallback(lambda check_results: self._fn.repair(check_results))
1320         def _check_results(rres):
1321             self.failUnless(IRepairResults.providedBy(rres))
1322             # TODO: examine results
1323
1324             self.copy_shares()
1325
1326             initial_shares = self.old_shares[0]
1327             new_shares = self.old_shares[1]
1328             # TODO: this really shouldn't change anything. When we implement
1329             # a "minimal-bandwidth" repairer", change this test to assert:
1330             #self.failUnlessEqual(new_shares, initial_shares)
1331
1332             # all shares should be in the same place as before
1333             self.failUnlessEqual(set(initial_shares.keys()),
1334                                  set(new_shares.keys()))
1335             # but they should all be at a newer seqnum. The IV will be
1336             # different, so the roothash will be too.
1337             for key in initial_shares:
1338                 (version0,
1339                  seqnum0,
1340                  root_hash0,
1341                  IV0,
1342                  k0, N0, segsize0, datalen0,
1343                  o0) = unpack_header(initial_shares[key])
1344                 (version1,
1345                  seqnum1,
1346                  root_hash1,
1347                  IV1,
1348                  k1, N1, segsize1, datalen1,
1349                  o1) = unpack_header(new_shares[key])
1350                 self.failUnlessEqual(version0, version1)
1351                 self.failUnlessEqual(seqnum0+1, seqnum1)
1352                 self.failUnlessEqual(k0, k1)
1353                 self.failUnlessEqual(N0, N1)
1354                 self.failUnlessEqual(segsize0, segsize1)
1355                 self.failUnlessEqual(datalen0, datalen1)
1356         d.addCallback(_check_results)
1357         return d
1358
1359     def failIfSharesChanged(self, ignored=None):
1360         old_shares = self.old_shares[-2]
1361         current_shares = self.old_shares[-1]
1362         self.failUnlessEqual(old_shares, current_shares)
1363
1364     def test_merge(self):
1365         self.old_shares = []
1366         d = self.publish_multiple()
1367         # repair will refuse to merge multiple highest seqnums unless you
1368         # pass force=True
1369         d.addCallback(lambda res:
1370                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1371                                           1:4,3:4,5:4,7:4,9:4}))
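             # (even shnums hold the version at index 3 and odd shnums the one
             # at index 4: the two parallel seqnum-4 versions, s4a and s4b,
             # compared via get_roothash_for() below)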
1372         d.addCallback(self.copy_shares)
1373         d.addCallback(lambda res: self._fn.check(Monitor()))
1374         def _try_repair(check_results):
1375             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1376             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1377                                  self._fn.repair, check_results)
1378             d2.addCallback(self.copy_shares)
1379             d2.addCallback(self.failIfSharesChanged)
1380             d2.addCallback(lambda res: check_results)
1381             return d2
1382         d.addCallback(_try_repair)
1383         d.addCallback(lambda check_results:
1384                       self._fn.repair(check_results, force=True))
1385         # this should give us 10 shares of the highest roothash
1386         def _check_repair_results(rres):
1387             pass # TODO
1388         d.addCallback(_check_repair_results)
1389         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1390         def _check_smap(smap):
1391             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1392             self.failIf(smap.unrecoverable_versions())
1393             # now, which should have won?
1394             roothash_s4a = self.get_roothash_for(3)
1395             roothash_s4b = self.get_roothash_for(4)
1396             if roothash_s4b > roothash_s4a:
1397                 expected_contents = self.CONTENTS[4]
1398             else:
1399                 expected_contents = self.CONTENTS[3]
1400             new_versionid = smap.best_recoverable_version()
1401             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1402             d2 = self._fn.download_version(smap, new_versionid)
1403             d2.addCallback(self.failUnlessEqual, expected_contents)
1404             return d2
1405         d.addCallback(_check_smap)
1406         return d
1407
1408     def test_non_merge(self):
1409         self.old_shares = []
1410         d = self.publish_multiple()
1411         # repair should not refuse a repair that doesn't need to merge. In
1412         # this case, we combine v2 with v3. The repair should ignore v2 and
1413         # copy v3 into a new v5.
1414         d.addCallback(lambda res:
1415                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1416                                           1:3,3:3,5:3,7:3,9:3}))
1417         d.addCallback(lambda res: self._fn.check(Monitor()))
1418         d.addCallback(lambda check_results: self._fn.repair(check_results))
1419         # this should give us 10 shares of v3
1420         def _check_repair_results(rres):
1421             pass # TODO
1422         d.addCallback(_check_repair_results)
1423         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1424         def _check_smap(smap):
1425             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1426             self.failIf(smap.unrecoverable_versions())
1427             # v3 should have won, since the repair ignored v2
1428             expected_contents = self.CONTENTS[3]
1430             new_versionid = smap.best_recoverable_version()
1431             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1432             d2 = self._fn.download_version(smap, new_versionid)
1433             d2.addCallback(self.failUnlessEqual, expected_contents)
1434             return d2
1435         d.addCallback(_check_smap)
1436         return d
1437
1438     def get_roothash_for(self, index):
1439         # return the roothash for the first share we see in the saved set
1440         shares = self._copied_shares[index]
1441         for peerid in shares:
1442             for shnum in shares[peerid]:
1443                 share = shares[peerid][shnum]
1444                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1445                           unpack_header(share)
1446                 return root_hash
1447
1448 class MultipleEncodings(unittest.TestCase):
1449     def setUp(self):
1450         self.CONTENTS = "New contents go here"
1451         num_peers = 20
1452         self._client = FakeClient(num_peers)
1453         self._storage = self._client._storage
1454         d = self._client.create_mutable_file(self.CONTENTS)
1455         def _created(node):
1456             self._fn = node
1457         d.addCallback(_created)
1458         return d
1459
1460     def _encode(self, k, n, data):
1461         # encode 'data' into a peerid->shares dict.
1462
1463         fn2 = FastMutableFileNode(self._client)
1464         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1465         # and _fingerprint
1466         fn = self._fn
1467         fn2.init_from_uri(fn.get_uri())
1468         # then we copy over other fields that are normally fetched from the
1469         # existing shares
1470         fn2._pubkey = fn._pubkey
1471         fn2._privkey = fn._privkey
1472         fn2._encprivkey = fn._encprivkey
1473         # and set the encoding parameters to something completely different
1474         fn2._required_shares = k
1475         fn2._total_shares = n
1476
1477         s = self._client._storage
1478         s._peers = {} # clear existing storage
1479         p2 = Publish(fn2, None)
1480         d = p2.publish(data)
1481         def _published(res):
1482             shares = s._peers
1483             s._peers = {}
1484             return shares
1485         d.addCallback(_published)
1486         return d
1487
1488     def make_servermap(self, mode=MODE_READ, oldmap=None):
1489         if oldmap is None:
1490             oldmap = ServerMap()
1491         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1492         d = smu.update()
1493         return d
1494
1495     def test_multiple_encodings(self):
1496         # we encode the same file in three different ways (3-of-10, 4-of-9,
1497         # and 4-of-7), then mix up the shares, to make sure that download
1498         # survives seeing a variety of encodings. This is tricky to set up.
1499
1500         contents1 = "Contents for encoding 1 (3-of-10) go here"
1501         contents2 = "Contents for encoding 2 (4-of-9) go here"
1502         contents3 = "Contents for encoding 3 (4-of-7) go here"
1503
1504         # we make a retrieval object that doesn't know what encoding
1505         # parameters to use
1506         fn3 = FastMutableFileNode(self._client)
1507         fn3.init_from_uri(self._fn.get_uri())
1508
1509         # now we publish the file with the first encoding, and grab its shares
1510         d = self._encode(3, 10, contents1)
1511         def _encoded_1(shares):
1512             self._shares1 = shares
1513         d.addCallback(_encoded_1)
1514         d.addCallback(lambda res: self._encode(4, 9, contents2))
1515         def _encoded_2(shares):
1516             self._shares2 = shares
1517         d.addCallback(_encoded_2)
1518         d.addCallback(lambda res: self._encode(4, 7, contents3))
1519         def _encoded_3(shares):
1520             self._shares3 = shares
1521         d.addCallback(_encoded_3)
1522
1523         def _merge(res):
1524             log.msg("merging sharelists")
1525             # we merge the shares from the two sets, leaving each shnum in
1526             # its original location, but using a share from set1 or set2
1527             # according to the following sequence:
1528             #
1529             #  4-of-9  a  s2
1530             #  4-of-9  b  s2
1531             #  4-of-7  c   s3
1532             #  4-of-9  d  s2
1533             #  3-of-10 e s1
1534             #  3-of-10 f s1
1535             #  3-of-10 g s1
1536             #  4-of-9  h  s2
1537             #
1538             # so that neither form can be recovered until fetch [f], at which
1539             # point version-s1 (the 3-of-10 form) should be recoverable. If
1540             # the implementation latches on to the first version it sees,
1541             # then s2 will be recoverable at fetch [g].
1542
1543             # Later, when we implement code that handles multiple versions,
1544             # we can use this framework to assert that all recoverable
1545             # versions are retrieved, and test that 'epsilon' does its job
1546
1547             places = [2, 2, 3, 2, 1, 1, 1, 2]
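                 # i.e. places[shnum] picks the source set for each shnum:
                 #   shnum:  0   1   2   3   4   5   6   7
                 #   set:    s2  s2  s3  s2  s1  s1  s1  s2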
1548
1549             sharemap = {}
1550
1551             for i,peerid in enumerate(self._client._peerids):
1552                 peerid_s = shortnodeid_b2a(peerid)
1553                 for shnum in self._shares1.get(peerid, {}):
1554                     if shnum < len(places):
1555                         which = places[shnum]
1556                     else:
1557                         which = "x"
1558                     self._client._storage._peers[peerid] = peers = {}
1559                     in_1 = shnum in self._shares1[peerid]
1560                     in_2 = shnum in self._shares2.get(peerid, {})
1561                     in_3 = shnum in self._shares3.get(peerid, {})
1562                     #print peerid_s, shnum, which, in_1, in_2, in_3
1563                     if which == 1:
1564                         if in_1:
1565                             peers[shnum] = self._shares1[peerid][shnum]
1566                             sharemap[shnum] = peerid
1567                     elif which == 2:
1568                         if in_2:
1569                             peers[shnum] = self._shares2[peerid][shnum]
1570                             sharemap[shnum] = peerid
1571                     elif which == 3:
1572                         if in_3:
1573                             peers[shnum] = self._shares3[peerid][shnum]
1574                             sharemap[shnum] = peerid
1575
1576             # we don't bother placing any other shares
1577             # now sort the sequence so that share 0 is returned first
1578             new_sequence = [sharemap[shnum]
1579                             for shnum in sorted(sharemap.keys())]
1580             self._client._storage._sequence = new_sequence
1581             log.msg("merge done")
1582         d.addCallback(_merge)
1583         d.addCallback(lambda res: fn3.download_best_version())
1584         def _retrieved(new_contents):
1585             # the current specified behavior is "first version recoverable"
1586             self.failUnlessEqual(new_contents, contents1)
1587         d.addCallback(_retrieved)
1588         return d
1589
1590
1591 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1592
1593     def setUp(self):
1594         return self.publish_multiple()
1595
1596     def test_multiple_versions(self):
1597         # if we see a mix of versions in the grid, download_best_version
1598         # should get the latest one
1599         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1600         d = self._fn.download_best_version()
1601         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1602         # and the checker should report problems
1603         d.addCallback(lambda res: self._fn.check(Monitor()))
1604         d.addCallback(self.check_bad, "test_multiple_versions")
1605
1606         # but if everything is at version 2, that's what we should download
1607         d.addCallback(lambda res:
1608                       self._set_versions(dict([(i,2) for i in range(10)])))
1609         d.addCallback(lambda res: self._fn.download_best_version())
1610         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1611         # if exactly one share is at version 3, we should still get v2
1612         d.addCallback(lambda res:
1613                       self._set_versions({0:3}))
1614         d.addCallback(lambda res: self._fn.download_best_version())
1615         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1616         # but the servermap should see the unrecoverable version. This
1617         # depends upon the single newer share being queried early.
1618         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1619         def _check_smap(smap):
1620             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1621             newer = smap.unrecoverable_newer_versions()
1622             self.failUnlessEqual(len(newer), 1)
1623             verinfo, health = newer.items()[0]
1624             self.failUnlessEqual(verinfo[0], 4)
1625             self.failUnlessEqual(health, (1,3))
1626             self.failIf(smap.needs_merge())
1627         d.addCallback(_check_smap)
1628         # if we have a mix of two parallel versions (s4a and s4b), we could
1629         # recover either
1630         d.addCallback(lambda res:
1631                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1632                                           1:4,3:4,5:4,7:4,9:4}))
1633         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1634         def _check_smap_mixed(smap):
1635             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1636             newer = smap.unrecoverable_newer_versions()
1637             self.failUnlessEqual(len(newer), 0)
1638             self.failUnless(smap.needs_merge())
1639         d.addCallback(_check_smap_mixed)
1640         d.addCallback(lambda res: self._fn.download_best_version())
1641         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1642                                                   res == self.CONTENTS[4]))
1643         return d
1644
1645     def test_replace(self):
1646         # if we see a mix of versions in the grid, we should be able to
1647         # replace them all with a newer version
1648
1649         # if exactly one share is at version 3 (seqnum 4), we should still
1650         # download (and replace) v2 (seqnum 3); the replacement gets seqnum 5.
1651         # Note that the index we give to _set_versions differs from the seqnum.
1652         target = dict([(i,2) for i in range(10)]) # seqnum3
1653         target[0] = 3 # seqnum4
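             # (index i in self.CONTENTS apparently maps to seqnum i+1 for
             # i <= 3; index 4 is a second, parallel seqnum-4 version, as
             # exercised by Repair.test_merge)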
1654         self._set_versions(target)
1655
1656         def _modify(oldversion):
1657             return oldversion + " modified"
1658         d = self._fn.modify(_modify)
1659         d.addCallback(lambda res: self._fn.download_best_version())
1660         expected = self.CONTENTS[2] + " modified"
1661         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1662         # and the servermap should indicate that the outlier was replaced too
1663         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1664         def _check_smap(smap):
1665             self.failUnlessEqual(smap.highest_seqnum(), 5)
1666             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1667             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1668         d.addCallback(_check_smap)
1669         return d
1670
1671
1672 class Utils(unittest.TestCase):
1673     def test_dict_of_sets(self):
1674         ds = DictOfSets()
1675         ds.add(1, "a")
1676         ds.add(2, "b")
1677         ds.add(2, "b")
1678         ds.add(2, "c")
1679         self.failUnlessEqual(ds[1], set(["a"]))
1680         self.failUnlessEqual(ds[2], set(["b", "c"]))
1681         ds.discard(3, "d") # should not raise an exception
1682         ds.discard(2, "b")
1683         self.failUnlessEqual(ds[2], set(["c"]))
1684         ds.discard(2, "c")
1685         self.failIf(2 in ds)
1686
1687     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1688         # we compare this against sets of integers
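             # e.g. x_start=10, x_length=5 covers bytes [10,15): that range is
             # inside y_start=8, y_length=7 ([8,15)) but not inside
             # y_start=8, y_length=6 ([8,14))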
1689         x = set(range(x_start, x_start+x_length))
1690         y = set(range(y_start, y_start+y_length))
1691         should_be_inside = x.issubset(y)
1692         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1693                                                          y_start, y_length),
1694                              str((x_start, x_length, y_start, y_length)))
1695
1696     def test_cache_inside(self):
1697         c = ResponseCache()
1698         x_start = 10
1699         x_length = 5
1700         for y_start in range(8, 17):
1701             for y_length in range(8):
1702                 self._do_inside(c, x_start, x_length, y_start, y_length)
1703
1704     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1705         # we compare this against sets of integers
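             # e.g. bytes [10,15) overlap bytes [13,20) but not bytes [15,20)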
1706         x = set(range(x_start, x_start+x_length))
1707         y = set(range(y_start, y_start+y_length))
1708         overlap = bool(x.intersection(y))
1709         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1710                                                       y_start, y_length),
1711                              str((x_start, x_length, y_start, y_length)))
1712
1713     def test_cache_overlap(self):
1714         c = ResponseCache()
1715         x_start = 10
1716         x_length = 5
1717         for y_start in range(8, 17):
1718             for y_length in range(8):
1719                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1720
1721     def test_cache(self):
1722         c = ResponseCache()
1723         # xdata = base62.b2a(os.urandom(100))[:100]
1724         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1725         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1726         nope = (None, None)
1727         c.add("v1", 1, 0, xdata, "time0")
1728         c.add("v1", 1, 2000, ydata, "time1")
1729         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1730         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1731         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1732         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1733         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1734         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1735         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1736         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1737         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1738         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1739         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1740         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1741         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1742         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1743         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1744         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1745         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1746         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1747
1748         # optional: join fragments
1749         c = ResponseCache()
1750         c.add("v1", 1, 0, xdata[:10], "time0")
1751         c.add("v1", 1, 10, xdata[10:20], "time1")
1752         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1753
1754 class Exceptions(unittest.TestCase):
1755     def test_repr(self):
1756         nmde = NeedMoreDataError(100, 50, 100)
1757         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1758         ucwe = UncoordinatedWriteError()
1759         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1760
1761 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1762 # instances which always succeed. So we need a less-fake one.
1763
1764 class IntentionalError(Exception):
1765     pass
1766
1767 class LocalWrapper:
1768     def __init__(self, original):
1769         self.original = original
1770         self.broken = False
1771         self.post_call_notifier = None
1772     def callRemote(self, methname, *args, **kwargs):
1773         def _call():
1774             if self.broken:
1775                 raise IntentionalError("I was asked to break")
1776             meth = getattr(self.original, "remote_" + methname)
1777             return meth(*args, **kwargs)
1778         d = fireEventually()
1779         d.addCallback(lambda res: _call())
1780         if self.post_call_notifier:
1781             d.addCallback(self.post_call_notifier, methname)
1782         return d
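
     # A rough usage sketch (the remote method and arguments here are just
     # illustrative): wrap a server, then flip .broken to make later calls
     # errback with IntentionalError:
     #   lw = LocalWrapper(ss)
     #   d = lw.callRemote("get_buckets", si)   # invokes ss.remote_get_buckets(si)
     #   lw.broken = True
     #   d2 = lw.callRemote("get_buckets", si)  # now errbacks with IntentionalError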
1783
1784 class LessFakeClient(FakeClient):
1785
1786     def __init__(self, basedir, num_peers=10):
1787         self._num_peers = num_peers
1788         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1789                          for i in range(self._num_peers)]
1790         self._connections = {}
1791         for peerid in self._peerids:
1792             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1793             make_dirs(peerdir)
1794             ss = storage.StorageServer(peerdir)
1795             ss.setNodeID(peerid)
1796             lw = LocalWrapper(ss)
1797             self._connections[peerid] = lw
1798         self.nodeid = "fakenodeid"
1799
1800
1801 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1802     def test_publish_surprise(self):
1803         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1804         self.client = LessFakeClient(basedir)
1805         d = self.client.create_mutable_file("contents 1")
1806         def _created(n):
1807             d = defer.succeed(None)
1808             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1809             def _got_smap1(smap):
1810                 # stash the old state of the file
1811                 self.old_map = smap
1812             d.addCallback(_got_smap1)
1813             # then modify the file, leaving the old map untouched
1814             d.addCallback(lambda res: log.msg("starting winning write"))
1815             d.addCallback(lambda res: n.overwrite("contents 2"))
1816             # now attempt to modify the file with the old servermap. This
1817             # will look just like an uncoordinated write, in which every
1818             # single share got updated between our mapupdate and our publish
1819             d.addCallback(lambda res: log.msg("starting doomed write"))
1820             d.addCallback(lambda res:
1821                           self.shouldFail(UncoordinatedWriteError,
1822                                           "test_publish_surprise", None,
1823                                           n.upload,
1824                                           "contents 2a", self.old_map))
1825             return d
1826         d.addCallback(_created)
1827         return d
1828
1829     def test_retrieve_surprise(self):
1830         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1831         self.client = LessFakeClient(basedir)
1832         d = self.client.create_mutable_file("contents 1")
1833         def _created(n):
1834             d = defer.succeed(None)
1835             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1836             def _got_smap1(smap):
1837                 # stash the old state of the file
1838                 self.old_map = smap
1839             d.addCallback(_got_smap1)
1840             # then modify the file, leaving the old map untouched
1841             d.addCallback(lambda res: log.msg("starting winning write"))
1842             d.addCallback(lambda res: n.overwrite("contents 2"))
1843             # now attempt to retrieve the old version with the old servermap.
1844             # This will look like someone has changed the file since we
1845             # updated the servermap.
1846             d.addCallback(lambda res: n._cache._clear())
1847             d.addCallback(lambda res: log.msg("starting doomed read"))
1848             d.addCallback(lambda res:
1849                           self.shouldFail(NotEnoughSharesError,
1850                                           "test_retrieve_surprise",
1851                                           "ran out of peers: have 0 shares (k=3)",
1852                                           n.download_version,
1853                                           self.old_map,
1854                                           self.old_map.best_recoverable_version(),
1855                                           ))
1856             return d
1857         d.addCallback(_created)
1858         return d
1859
1860     def test_unexpected_shares(self):
1861         # upload the file, take a servermap, shut down one of the servers,
1862         # upload it again (causing shares to appear on a new server), then
1863         # upload using the old servermap. The last upload should fail with an
1864         # UncoordinatedWriteError, because of the shares that didn't appear
1865         # in the servermap.
1866         basedir = os.path.join("mutable/CollidingWrites/test_unexpexted_shares")
1867         self.client = LessFakeClient(basedir)
1868         d = self.client.create_mutable_file("contents 1")
1869         def _created(n):
1870             d = defer.succeed(None)
1871             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1872             def _got_smap1(smap):
1873                 # stash the old state of the file
1874                 self.old_map = smap
1875                 # now shut down one of the servers
1876                 peer0 = list(smap.make_sharemap()[0])[0]
1877                 self.client._connections.pop(peer0)
1878                 # then modify the file, leaving the old map untouched
1879                 log.msg("starting winning write")
1880                 return n.overwrite("contents 2")
1881             d.addCallback(_got_smap1)
1882             # now attempt to modify the file with the old servermap. This
1883             # will look just like an uncoordinated write, in which every
1884             # single share got updated between our mapupdate and our publish
1885             d.addCallback(lambda res: log.msg("starting doomed write"))
1886             d.addCallback(lambda res:
1887                           self.shouldFail(UncoordinatedWriteError,
1888                                           "test_surprise", None,
1889                                           n.upload,
1890                                           "contents 2a", self.old_map))
1891             return d
1892         d.addCallback(_created)
1893         return d
1894
1895     def test_bad_server(self):
1896         # Break one server, then create the file: the initial publish should
1897         # complete with an alternate server. Breaking a second server should
1898         # not prevent an update from succeeding either.
1899         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1900         self.client = LessFakeClient(basedir, 20)
1901         # to make sure that one of the initial peers is broken, we have to
1902         # get creative. We create the keys, so we can figure out the storage
1903         # index, but we hold off on doing the initial publish until we've
1904         # broken the server on which the first share wants to be stored.
1905         n = FastMutableFileNode(self.client)
1906         d = defer.succeed(None)
1907         d.addCallback(n._generate_pubprivkeys)
1908         d.addCallback(n._generated)
1909         def _break_peer0(res):
1910             si = n.get_storage_index()
1911             peerlist = self.client.get_permuted_peers("storage", si)
1912             peerid0, connection0 = peerlist[0]
1913             peerid1, connection1 = peerlist[1]
1914             connection0.broken = True
1915             self.connection1 = connection1
1916         d.addCallback(_break_peer0)
1917         # now let the initial publish finally happen
1918         d.addCallback(lambda res: n._upload("contents 1", None))
1919         # that ought to work
1920         d.addCallback(lambda res: n.download_best_version())
1921         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1922         # now break the second peer
1923         def _break_peer1(res):
1924             self.connection1.broken = True
1925         d.addCallback(_break_peer1)
1926         d.addCallback(lambda res: n.overwrite("contents 2"))
1927         # that ought to work too
1928         d.addCallback(lambda res: n.download_best_version())
1929         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1930         return d
1931
1932     def test_publish_all_servers_bad(self):
1933         # Break all servers: the publish should fail
1934         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1935         self.client = LessFakeClient(basedir, 20)
1936         for connection in self.client._connections.values():
1937             connection.broken = True
1938         d = self.shouldFail(NotEnoughServersError,
1939                             "test_publish_all_servers_bad",
1940                             "Ran out of non-bad servers",
1941                             self.client.create_mutable_file, "contents")
1942         return d
1943
1944     def test_publish_no_servers(self):
1945         # no servers at all: the publish should fail
1946         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1947         self.client = LessFakeClient(basedir, 0)
1948         d = self.shouldFail(NotEnoughServersError,
1949                             "test_publish_no_servers",
1950                             "Ran out of non-bad servers",
1951                             self.client.create_mutable_file, "contents")
1952         return d
1953     test_publish_no_servers.timeout = 30
1954
1955
1956     def test_privkey_query_error(self):
1957         # when a servermap is updated with MODE_WRITE, it tries to get the
1958         # privkey. Something might go wrong during this query attempt.
1959         self.client = FakeClient(20)
1960         # we need some contents that are large enough to push the privkey out
1961         # of the early part of the file
1962         LARGE = "These are Larger contents" * 200 # about 5KB
1963         d = self.client.create_mutable_file(LARGE)
1964         def _created(n):
1965             self.uri = n.get_uri()
1966             self.n2 = self.client.create_node_from_uri(self.uri)
1967             # we start by doing a map update to figure out which is the first
1968             # server.
1969             return n.get_servermap(MODE_WRITE)
1970         d.addCallback(_created)
1971         d.addCallback(lambda res: fireEventually(res))
1972         def _got_smap1(smap):
1973             peer0 = list(smap.make_sharemap()[0])[0]
1974             # we tell the server to respond to this peer first, so that it
1975             # will be asked for the privkey first
1976             self.client._storage._sequence = [peer0]
1977             # now we make the peer fail their second query
1978             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1979         d.addCallback(_got_smap1)
1980         # now we update a servermap from a new node (which doesn't have the
1981         # privkey yet, forcing it to use a separate privkey query). Each
1982         # query response will trigger a privkey query, and since we're using
1983         # _sequence to make the peer0 response come back first, we'll send it
1984         # a privkey query first, and _sequence will again ensure that the
1985         # peer0 query will also come back before the others, and then
1986         # _special_answers will make sure that the query raises an exception.
1987         # The whole point of these hijinks is to exercise the code in
1988         # _privkey_query_failed. Note that the map-update will succeed, since
1989         # we'll just get a copy from one of the other shares.
1990         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1991         # Using FakeStorage._sequence means there will be read requests still
1992         # floating around; wait for them to retire
1993         def _cancel_timer(res):
1994             if self.client._storage._pending_timer:
1995                 self.client._storage._pending_timer.cancel()
1996             return res
1997         d.addBoth(_cancel_timer)
1998         return d
1999
2000     def test_privkey_query_missing(self):
2001         # like test_privkey_query_error, but the shares are deleted by the
2002         # second query, instead of raising an exception.
2003         self.client = FakeClient(20)
2004         LARGE = "These are Larger contents" * 200 # about 5KB
2005         d = self.client.create_mutable_file(LARGE)
2006         def _created(n):
2007             self.uri = n.get_uri()
2008             self.n2 = self.client.create_node_from_uri(self.uri)
2009             return n.get_servermap(MODE_WRITE)
2010         d.addCallback(_created)
2011         d.addCallback(lambda res: fireEventually(res))
2012         def _got_smap1(smap):
2013             peer0 = list(smap.make_sharemap()[0])[0]
2014             self.client._storage._sequence = [peer0]
2015             self.client._storage._special_answers[peer0] = ["normal", "none"]
2016         d.addCallback(_got_smap1)
2017         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2018         def _cancel_timer(res):
2019             if self.client._storage._pending_timer:
2020                 self.client._storage._pending_timer.cancel()
2021             return res
2022         d.addBoth(_cancel_timer)
2023         return d