
import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, storage
from allmydata.immutable import download
from allmydata.util import base32, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
     FileTooLargeError, NotEnoughSharesError, IRepairResults
from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError, CorruptShareError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share
from allmydata.mutable.repair import MustForceRepairError

import common_util as testutil

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
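        # _special_answers maps a peerid to a list of canned response modes,
        # consumed one per read(): "fail" answers with a Failure, "none"
        # answers with an empty share dict, "normal" answers normally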
        self._special_answers = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
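        # answers for peers not named in self._sequence are fired last, in
        # arbitrary dict order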
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
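        # fireEventually() fires on a later reactor turn, so every fake
        # remote call completes asynchronously, like a real network
        # round-trip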
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d
    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        d.addBoth(lambda ignore: None)

    def advise_corrupt_share(self, share_type, storage_index, shnum, reason):
        pass

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def get_encoding_parameters(self):
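        # match the client's default 3-of-10 encoding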
        return {"k": 3, "n": 10}

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
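            # 'offset' may name a field in the share layout (resolved via
            # the offsets dict 'o'), be the string "pubkey" (which starts at
            # a fixed position in these test shares), or be a raw integer
            # byte position; a (name, delta) tuple flips a bit 'delta' bytes
            # into the named field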
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2 + offset_offset
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            self.failUnlessEqual(n.get_storage_index(), n._storage_index)
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.get_size_of_best_version())
            d.addCallback(lambda size:
                          self.failUnlessEqual(size, len("contents 1")))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_too_large_contents(self):
        BIG = "a" * (Publish.MAX_SEGMENT_SIZE+1)
        d = self.shouldFail(FileTooLargeError, "too_large",
                            "SDMF is limited to one segment, and %d > %d" %
                            (len(BIG), Publish.MAX_SEGMENT_SIZE),
                            self.client.create_mutable_file, BIG)
        d.addCallback(lambda res: self.client.create_mutable_file("small"))
        def _created(n):
            return self.shouldFail(FileTooLargeError, "too_large_2",
                                   "SDMF is limited to one segment, and %d > %d" %
                                   (len(BIG), Publish.MAX_SEGMENT_SIZE),
                                   n.overwrite, BIG)
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        def _toobig_modifier(old_contents):
            return "b" * (Publish.MAX_SEGMENT_SIZE+1)
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(FileTooLargeError, "toobig_modifier",
                                          "SDMF is limited to one segment",
                                          n.modify, _toobig_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
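                # the 21-byte CONTENTS, encoded with k=3, yields 7-byte
                # shares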
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them. When we publish to 3 peers, we should get either 3 or 4
    # shares per peer. When we publish to zero peers, we should get a
    # NotEnoughSharesError.

class PublishMixin:
    def publish_one(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here" * 1000
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d
    def publish_multiple(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]


class Servermap(unittest.TestCase, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, Monitor(), ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
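        # shares_available() maps each verinfo to a (num_shares, k, N) tuple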
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum, "")
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3,10) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
    def setUp(self):
        return self.publish_one()

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def test_all_shares_vanished(self):
        d = self.make_servermap()
        def _remove_shares(servermap):
            for shares in self._storage._peers.values():
                shares.clear()
            d1 = self.shouldFail(NotEnoughSharesError,
                                 "test_all_shares_vanished",
                                 "ran out of peers",
                                 self.do_download, servermap)
            return d1
        d.addCallback(_remove_shares)
        return d

    def test_no_servers(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        # if there are no servers, then a MODE_READ servermap should come
        # back empty
        d = self.make_servermap()
        def _check_servermap(servermap):
            self.failUnlessEqual(servermap.best_recoverable_version(), None)
            self.failIf(servermap.recoverable_versions())
            self.failIf(servermap.unrecoverable_versions())
            self.failIf(servermap.all_peers())
        d.addCallback(_check_servermap)
        return d
    test_no_servers.timeout = 15

    def test_no_servers_download(self):
        c2 = FakeClient(0)
        self._fn._client = c2
        d = self.shouldFail(UnrecoverableFileError,
                            "test_no_servers_download",
                            "no recoverable versions",
                            self._fn.download_best_version)
        def _restore(res):
            # a failed download that occurs while we aren't connected to
            # anybody should not prevent a subsequent download from working.
            # This isn't quite the webapi-driven test that #463 wants, but it
            # should be close enough.
            self._fn._client = self._client
            return self._fn.download_best_version()
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_restore)
        d.addCallback(_retrieved)
        return d
    test_no_servers_download.timeout = 15

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True,
                          failure_checker=None):
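        # corrupt_early=True damages the shares before the servermap is
        # built, so mapupdate sees the problem; corrupt_early=False damages
        # them between mapupdate and retrieve, exercising the retrieve-time
        # error paths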
1003         d = defer.succeed(None)
1004         if corrupt_early:
1005             d.addCallback(corrupt, self._storage, offset)
1006         d.addCallback(lambda res: self.make_servermap())
1007         if not corrupt_early:
1008             d.addCallback(corrupt, self._storage, offset)
1009         def _do_retrieve(servermap):
1010             ver = servermap.best_recoverable_version()
1011             if ver is None and not should_succeed:
1012                 # no recoverable versions == not succeeding. The problem
1013                 # should be noted in the servermap's list of problems.
1014                 if substring:
1015                     allproblems = [str(f) for f in servermap.problems]
1016                     self.failUnless(substring in "".join(allproblems))
1017                 return servermap
1018             if should_succeed:
1019                 d1 = self._fn.download_version(servermap, ver)
1020                 d1.addCallback(lambda new_contents:
1021                                self.failUnlessEqual(new_contents, self.CONTENTS))
1022             else:
1023                 d1 = self.shouldFail(NotEnoughSharesError,
1024                                      "_corrupt_all(offset=%s)" % (offset,),
1025                                      substring,
1026                                      self._fn.download_version, servermap, ver)
1027             if failure_checker:
1028                 d1.addCallback(failure_checker)
1029             d1.addCallback(lambda res: servermap)
1030             return d1
1031         d.addCallback(_do_retrieve)
1032         return d
1033
1034     def test_corrupt_all_verbyte(self):
1035         # when the version byte is not 0, we hit an assertion error in
1036         # unpack_share().
1037         d = self._test_corrupt_all(0, "AssertionError")
1038         def _check_servermap(servermap):
1039             # and the dump should mention the problems
1040             s = StringIO()
1041             dump = servermap.dump(s).getvalue()
1042             self.failUnless("10 PROBLEMS" in dump, dump)
1043         d.addCallback(_check_servermap)
1044         return d
1045
1046     def test_corrupt_all_seqnum(self):
1047         # a corrupt sequence number will trigger a bad signature
1048         return self._test_corrupt_all(1, "signature is invalid")
1049
1050     def test_corrupt_all_R(self):
1051         # a corrupt root hash will trigger a bad signature
1052         return self._test_corrupt_all(9, "signature is invalid")
1053
1054     def test_corrupt_all_IV(self):
1055         # a corrupt salt/IV will trigger a bad signature
1056         return self._test_corrupt_all(41, "signature is invalid")
1057
1058     def test_corrupt_all_k(self):
1059         # a corrupt 'k' will trigger a bad signature
1060         return self._test_corrupt_all(57, "signature is invalid")
1061
1062     def test_corrupt_all_N(self):
1063         # a corrupt 'N' will trigger a bad signature
1064         return self._test_corrupt_all(58, "signature is invalid")
1065
1066     def test_corrupt_all_segsize(self):
1067         # a corrupt segsize will trigger a bad signature
1068         return self._test_corrupt_all(59, "signature is invalid")
1069
1070     def test_corrupt_all_datalen(self):
1071         # a corrupt data length will trigger a bad signature
1072         return self._test_corrupt_all(67, "signature is invalid")
1073
1074     def test_corrupt_all_pubkey(self):
1075         # a corrupt pubkey won't match the URI's fingerprint. We need to
1076         # remove the pubkey from the filenode, or else it won't bother trying
1077         # to update it.
1078         self._fn._pubkey = None
1079         return self._test_corrupt_all("pubkey",
1080                                       "pubkey doesn't match fingerprint")
1081
1082     def test_corrupt_all_sig(self):
1083         # a corrupt signature is a bad one
1084         # the signature runs from about [543:799], depending upon the length
1085         # of the pubkey
1086         return self._test_corrupt_all("signature", "signature is invalid")
1087
1088     def test_corrupt_all_share_hash_chain_number(self):
1089         # a corrupt share hash chain entry will show up as a bad hash. If we
1090         # mangle the first byte, that will look like a bad hash number,
1091         # causing an IndexError
1092         return self._test_corrupt_all("share_hash_chain", "corrupt hashes")
1093
1094     def test_corrupt_all_share_hash_chain_hash(self):
1095         # a corrupt share hash chain entry will show up as a bad hash. If we
1096         # mangle a few bytes in, that will look like a bad hash.
1097         return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")
1098
1099     def test_corrupt_all_block_hash_tree(self):
1100         return self._test_corrupt_all("block_hash_tree",
1101                                       "block hash tree failure")
1102
1103     def test_corrupt_all_block(self):
1104         return self._test_corrupt_all("share_data", "block hash tree failure")
1105
1106     def test_corrupt_all_encprivkey(self):
1107         # a corrupted privkey won't even be noticed by the reader, only by a
1108         # writer.
1109         return self._test_corrupt_all("enc_privkey", None, should_succeed=True)
1110
1111
1112     def test_corrupt_all_seqnum_late(self):
1113         # corrupting the seqnum between mapupdate and retrieve should result
1114         # in NotEnoughSharesError, since each share will look invalid
1115         def _check(res):
1116             f = res[0]
1117             self.failUnless(f.check(NotEnoughSharesError))
1118             self.failUnless("someone wrote to the data since we read the servermap" in str(f))
1119         return self._test_corrupt_all(1, "ran out of peers",
1120                                       corrupt_early=False,
1121                                       failure_checker=_check)
1122
1123     def test_corrupt_all_block_hash_tree_late(self):
1124         def _check(res):
1125             f = res[0]
1126             self.failUnless(f.check(NotEnoughSharesError))
1127         return self._test_corrupt_all("block_hash_tree",
1128                                       "block hash tree failure",
1129                                       corrupt_early=False,
1130                                       failure_checker=_check)
1131
1132
1133     def test_corrupt_all_block_late(self):
1134         def _check(res):
1135             f = res[0]
1136             self.failUnless(f.check(NotEnoughSharesError))
1137         return self._test_corrupt_all("share_data", "block hash tree failure",
1138                                       corrupt_early=False,
1139                                       failure_checker=_check)
1140
1141
1142     def test_basic_pubkey_at_end(self):
1143         # we corrupt the pubkey in all but the last 'k' shares, allowing the
1144         # download to succeed but forcing a bunch of retries first. Note that
1145         # this is rather pessimistic: our Retrieve process will throw away
1146         # the whole share if the pubkey is bad, even though the rest of the
1147         # share might be good.
1148
1149         self._fn._pubkey = None
1150         k = self._fn.get_required_shares()
1151         N = self._fn.get_total_shares()
1152         d = defer.succeed(None)
1153         d.addCallback(corrupt, self._storage, "pubkey",
1154                       shnums_to_corrupt=range(0, N-k))
1155         d.addCallback(lambda res: self.make_servermap())
1156         def _do_retrieve(servermap):
1157             self.failUnless(servermap.problems)
1158             self.failUnless("pubkey doesn't match fingerprint"
1159                             in str(servermap.problems[0]))
1160             ver = servermap.best_recoverable_version()
1161             r = Retrieve(self._fn, servermap, ver)
1162             return r.download()
1163         d.addCallback(_do_retrieve)
1164         d.addCallback(lambda new_contents:
1165                       self.failUnlessEqual(new_contents, self.CONTENTS))
1166         return d
1167
1168     def test_corrupt_some(self):
1169         # corrupt the data of first five shares (so the servermap thinks
1170         # they're good but retrieve marks them as bad), so that the
1171         # MODE_READ set of 6 will be insufficient, forcing node.download to
1172         # retry with more servers.
1173         corrupt(None, self._storage, "share_data", range(5))
1174         d = self.make_servermap()
1175         def _do_retrieve(servermap):
1176             ver = servermap.best_recoverable_version()
1177             self.failUnless(ver)
1178             return self._fn.download_best_version()
1179         d.addCallback(_do_retrieve)
1180         d.addCallback(lambda new_contents:
1181                       self.failUnlessEqual(new_contents, self.CONTENTS))
1182         return d
1183
1184     def test_download_fails(self):
1185         corrupt(None, self._storage, "signature")
1186         d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
1187                             "no recoverable versions",
1188                             self._fn.download_best_version)
1189         return d
1190
1191
1192 class CheckerMixin:
1193     def check_good(self, r, where):
1194         self.failUnless(r.is_healthy(), where)
1195         return r
1196
1197     def check_bad(self, r, where):
1198         self.failIf(r.is_healthy(), where)
1199         return r
1200
1201     def check_expected_failure(self, r, expected_exception, substring, where):
1202         for (peerid, storage_index, shnum, f) in r.problems:
1203             if f.check(expected_exception):
1204                 self.failUnless(substring in str(f),
1205                                 "%s: substring '%s' not in '%s'" %
1206                                 (where, substring, str(f)))
1207                 return
1208         self.fail("%s: didn't see expected exception %s in problems %s" %
1209                   (where, expected_exception, r.problems))
1210
1211
1212 class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
1213     def setUp(self):
1214         return self.publish_one()
1215
1216
1217     def test_check_good(self):
1218         d = self._fn.check(Monitor())
1219         d.addCallback(self.check_good, "test_check_good")
1220         return d
1221
1222     def test_check_no_shares(self):
1223         for shares in self._storage._peers.values():
1224             shares.clear()
1225         d = self._fn.check(Monitor())
1226         d.addCallback(self.check_bad, "test_check_no_shares")
1227         return d
1228
1229     def test_check_not_enough_shares(self):
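             # delete every share except shnum 0, leaving fewer than 'k'
             # distinct shares in the grid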
1230         for shares in self._storage._peers.values():
1231             for shnum in shares.keys():
1232                 if shnum > 0:
1233                     del shares[shnum]
1234         d = self._fn.check(Monitor())
1235         d.addCallback(self.check_bad, "test_check_not_enough_shares")
1236         return d
1237
1238     def test_check_all_bad_sig(self):
1239         corrupt(None, self._storage, 1) # bad sig
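             # (an offset of 1 falls inside the signed prefix, so signature
             # verification fails on every share)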
1240         d = self._fn.check(Monitor())
1241         d.addCallback(self.check_bad, "test_check_all_bad_sig")
1242         return d
1243
1244     def test_check_all_bad_blocks(self):
1245         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1246         # the Checker won't notice this, since it doesn't look at actual data
1247         d = self._fn.check(Monitor())
1248         d.addCallback(self.check_good, "test_check_all_bad_blocks")
1249         return d
1250
1251     def test_verify_good(self):
1252         d = self._fn.check(Monitor(), verify=True)
1253         d.addCallback(self.check_good, "test_verify_good")
1254         return d
1255
1256     def test_verify_all_bad_sig(self):
1257         corrupt(None, self._storage, 1) # bad sig
1258         d = self._fn.check(Monitor(), verify=True)
1259         d.addCallback(self.check_bad, "test_verify_all_bad_sig")
1260         return d
1261
1262     def test_verify_one_bad_sig(self):
1263         corrupt(None, self._storage, 1, [9]) # bad sig
1264         d = self._fn.check(Monitor(), verify=True)
1265         d.addCallback(self.check_bad, "test_verify_one_bad_sig")
1266         return d
1267
1268     def test_verify_one_bad_block(self):
1269         corrupt(None, self._storage, "share_data", [9]) # bad blocks
1270         # the Verifier *will* notice this, since it examines every byte
1271         d = self._fn.check(Monitor(), verify=True)
1272         d.addCallback(self.check_bad, "test_verify_one_bad_block")
1273         d.addCallback(self.check_expected_failure,
1274                       CorruptShareError, "block hash tree failure",
1275                       "test_verify_one_bad_block")
1276         return d
1277
1278     def test_verify_one_bad_sharehash(self):
1279         corrupt(None, self._storage, "share_hash_chain", [9], 5)
1280         d = self._fn.check(Monitor(), verify=True)
1281         d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
1282         d.addCallback(self.check_expected_failure,
1283                       CorruptShareError, "corrupt hashes",
1284                       "test_verify_one_bad_sharehash")
1285         return d
1286
1287     def test_verify_one_bad_encprivkey(self):
1288         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1289         d = self._fn.check(Monitor(), verify=True)
1290         d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
1291         d.addCallback(self.check_expected_failure,
1292                       CorruptShareError, "invalid privkey",
1293                       "test_verify_one_bad_encprivkey")
1294         return d
1295
1296     def test_verify_one_bad_encprivkey_uncheckable(self):
1297         corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
1298         readonly_fn = self._fn.get_readonly()
1299         # a read-only node has no way to validate the privkey
1300         d = readonly_fn.check(Monitor(), verify=True)
1301         d.addCallback(self.check_good,
1302                       "test_verify_one_bad_encprivkey_uncheckable")
1303         return d
1304
1305 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
1306
1307     def get_shares(self, s):
1308         all_shares = {} # maps (peerid, shnum) to share data
1309         for peerid in s._peers:
1310             shares = s._peers[peerid]
1311             for shnum in shares:
1312                 data = shares[shnum]
1313                 all_shares[ (peerid, shnum) ] = data
1314         return all_shares
1315
1316     def copy_shares(self, ignored=None):
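             # snapshot the current contents of the fake storage servers, so
             # that a later step can compare against them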
1317         self.old_shares.append(self.get_shares(self._storage))
1318
1319     def test_repair_nop(self):
1320         self.old_shares = []
1321         d = self.publish_one()
1322         d.addCallback(self.copy_shares)
1323         d.addCallback(lambda res: self._fn.check(Monitor()))
1324         d.addCallback(lambda check_results: self._fn.repair(check_results))
1325         def _check_results(rres):
1326             self.failUnless(IRepairResults.providedBy(rres))
1327             # TODO: examine results
1328
1329             self.copy_shares()
1330
1331             initial_shares = self.old_shares[0]
1332             new_shares = self.old_shares[1]
1333             # TODO: this really shouldn't change anything. When we implement
1334             # a "minimal-bandwidth" repairer, change this test to assert:
1335             #self.failUnlessEqual(new_shares, initial_shares)
1336
1337             # all shares should be in the same place as before
1338             self.failUnlessEqual(set(initial_shares.keys()),
1339                                  set(new_shares.keys()))
1340             # but they should all be at a newer seqnum. The IV will be
1341             # different, so the roothash will be too.
1342             for key in initial_shares:
1343                 (version0,
1344                  seqnum0,
1345                  root_hash0,
1346                  IV0,
1347                  k0, N0, segsize0, datalen0,
1348                  o0) = unpack_header(initial_shares[key])
1349                 (version1,
1350                  seqnum1,
1351                  root_hash1,
1352                  IV1,
1353                  k1, N1, segsize1, datalen1,
1354                  o1) = unpack_header(new_shares[key])
1355                 self.failUnlessEqual(version0, version1)
1356                 self.failUnlessEqual(seqnum0+1, seqnum1)
1357                 self.failUnlessEqual(k0, k1)
1358                 self.failUnlessEqual(N0, N1)
1359                 self.failUnlessEqual(segsize0, segsize1)
1360                 self.failUnlessEqual(datalen0, datalen1)
1361         d.addCallback(_check_results)
1362         return d
1363
1364     def failIfSharesChanged(self, ignored=None):
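             # compare the two most recent snapshots taken by copy_shares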
1365         old_shares = self.old_shares[-2]
1366         current_shares = self.old_shares[-1]
1367         self.failUnlessEqual(old_shares, current_shares)
1368
1369     def test_merge(self):
1370         self.old_shares = []
1371         d = self.publish_multiple()
1372         # repair will refuse to merge multiple highest seqnums unless you
1373         # pass force=True
1374         d.addCallback(lambda res:
1375                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1376                                           1:4,3:4,5:4,7:4,9:4}))
1377         d.addCallback(self.copy_shares)
1378         d.addCallback(lambda res: self._fn.check(Monitor()))
1379         def _try_repair(check_results):
1380             ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
1381             d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
1382                                  self._fn.repair, check_results)
1383             d2.addCallback(self.copy_shares)
1384             d2.addCallback(self.failIfSharesChanged)
1385             d2.addCallback(lambda res: check_results)
1386             return d2
1387         d.addCallback(_try_repair)
1388         d.addCallback(lambda check_results:
1389                       self._fn.repair(check_results, force=True))
1390         # this should give us 10 shares of the highest roothash
1391         def _check_repair_results(rres):
1392             pass # TODO
1393         d.addCallback(_check_repair_results)
1394         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1395         def _check_smap(smap):
1396             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1397             self.failIf(smap.unrecoverable_versions())
1398             # now, which should have won?
1399             roothash_s4a = self.get_roothash_for(3)
1400             roothash_s4b = self.get_roothash_for(4)
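                 # (a seqnum tie is broken in favor of the larger roothash)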
1401             if roothash_s4b > roothash_s4a:
1402                 expected_contents = self.CONTENTS[4]
1403             else:
1404                 expected_contents = self.CONTENTS[3]
1405             new_versionid = smap.best_recoverable_version()
1406             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1407             d2 = self._fn.download_version(smap, new_versionid)
1408             d2.addCallback(self.failUnlessEqual, expected_contents)
1409             return d2
1410         d.addCallback(_check_smap)
1411         return d
1412
1413     def test_non_merge(self):
1414         self.old_shares = []
1415         d = self.publish_multiple()
1416         # repair should not refuse to run when no merge is needed. In this
1417         # case, we combine the index-2 and index-3 versions; repair should
1418         # ignore the older one and republish the index-3 contents at seqnum 5.
1419         d.addCallback(lambda res:
1420                       self._set_versions({0:2,2:2,4:2,6:2,8:2,
1421                                           1:3,3:3,5:3,7:3,9:3}))
1422         d.addCallback(lambda res: self._fn.check(Monitor()))
1423         d.addCallback(lambda check_results: self._fn.repair(check_results))
1424         # this should give us 10 shares of v3
1425         def _check_repair_results(rres):
1426             pass # TODO
1427         d.addCallback(_check_repair_results)
1428         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1429         def _check_smap(smap):
1430             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1431             self.failIf(smap.unrecoverable_versions())
1432             # only one version is recoverable, so there is no contest
1434             expected_contents = self.CONTENTS[3]
1435             new_versionid = smap.best_recoverable_version()
1436             self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
1437             d2 = self._fn.download_version(smap, new_versionid)
1438             d2.addCallback(self.failUnlessEqual, expected_contents)
1439             return d2
1440         d.addCallback(_check_smap)
1441         return d
1442
1443     def get_roothash_for(self, index):
1444         # return the roothash for the first share we see in the saved set
1445         shares = self._copied_shares[index]
1446         for peerid in shares:
1447             for shnum in shares[peerid]:
1448                 share = shares[peerid][shnum]
1449                 (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
1450                           unpack_header(share)
1451                 return root_hash
1452
1453 class MultipleEncodings(unittest.TestCase):
1454     def setUp(self):
1455         self.CONTENTS = "New contents go here"
1456         num_peers = 20
1457         self._client = FakeClient(num_peers)
1458         self._storage = self._client._storage
1459         d = self._client.create_mutable_file(self.CONTENTS)
1460         def _created(node):
1461             self._fn = node
1462         d.addCallback(_created)
1463         return d
1464
1465     def _encode(self, k, n, data):
1466         # encode 'data' into a peerid->shares dict.
1467
1468         fn2 = FastMutableFileNode(self._client)
1469         # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
1470         # and _fingerprint
1471         fn = self._fn
1472         fn2.init_from_uri(fn.get_uri())
1473         # then we copy over other fields that are normally fetched from the
1474         # existing shares
1475         fn2._pubkey = fn._pubkey
1476         fn2._privkey = fn._privkey
1477         fn2._encprivkey = fn._encprivkey
1478         # and set the encoding parameters to something completely different
1479         fn2._required_shares = k
1480         fn2._total_shares = n
1481
1482         s = self._client._storage
1483         s._peers = {} # clear existing storage
1484         p2 = Publish(fn2, None)
1485         d = p2.publish(data)
1486         def _published(res):
1487             shares = s._peers
1488             s._peers = {}
1489             return shares
1490         d.addCallback(_published)
1491         return d
1492
1493     def make_servermap(self, mode=MODE_READ, oldmap=None):
1494         if oldmap is None:
1495             oldmap = ServerMap()
1496         smu = ServermapUpdater(self._fn, Monitor(), oldmap, mode)
1497         d = smu.update()
1498         return d
1499
1500     def test_multiple_encodings(self):
1501         # we encode the same file in three different ways (3-of-10, 4-of-9,
1502         # and 4-of-7), then mix up the shares, to make sure that download
1503         # survives seeing a variety of encodings. This is tricky to set up.
1504
1505         contents1 = "Contents for encoding 1 (3-of-10) go here"
1506         contents2 = "Contents for encoding 2 (4-of-9) go here"
1507         contents3 = "Contents for encoding 3 (4-of-7) go here"
1508
1509         # we make a retrieval object that doesn't know what encoding
1510         # parameters to use
1511         fn3 = FastMutableFileNode(self._client)
1512         fn3.init_from_uri(self._fn.get_uri())
1513
1514         # now we upload the file with the first encoding, and grab its shares
1515         d = self._encode(3, 10, contents1)
1516         def _encoded_1(shares):
1517             self._shares1 = shares
1518         d.addCallback(_encoded_1)
1519         d.addCallback(lambda res: self._encode(4, 9, contents2))
1520         def _encoded_2(shares):
1521             self._shares2 = shares
1522         d.addCallback(_encoded_2)
1523         d.addCallback(lambda res: self._encode(4, 7, contents3))
1524         def _encoded_3(shares):
1525             self._shares3 = shares
1526         d.addCallback(_encoded_3)
1527
1528         def _merge(res):
1529             log.msg("merging sharelists")
1530             # we merge the shares from the three sets, leaving each shnum in
1531             # its original location, but using a share from set1, set2, or
1532             # set3 according to the following sequence:
1533             #
1534             #  4-of-9  a  s2
1535             #  4-of-9  b  s2
1536             #  4-of-7  c   s3
1537             #  4-of-9  d  s2
1538             #  3-of-10 e s1
1539             #  3-of-10 f s1
1540             #  3-of-10 g s1
1541             #  4-of-9  h  s2
1542             #
1543             # so that neither form can be recovered until fetch [g], at which
1544             # point version-s1 (the 3-of-10 form) becomes recoverable: its
1545             # three shares arrive at fetches [e], [f], and [g]. The s2
1546             # (4-of-9) form does not see its fourth share until fetch [h].
1547
1548             # Later, when we implement code that handles multiple versions,
1549             # we can use this framework to assert that all recoverable
1550             # versions are retrieved, and test that 'epsilon' does its job
1551
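                 # places[shnum] says which share-set (1=s1, 2=s2, 3=s3)
                 # supplies that shnum, matching the table above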
1552             places = [2, 2, 3, 2, 1, 1, 1, 2]
1553
1554             sharemap = {}
1555
1556             for i,peerid in enumerate(self._client._peerids):
1557                 peerid_s = shortnodeid_b2a(peerid)
1558                 for shnum in self._shares1.get(peerid, {}):
1559                     if shnum < len(places):
1560                         which = places[shnum]
1561                     else:
1562                         which = "x"
1563                     self._client._storage._peers[peerid] = peers = {}
1564                     in_1 = shnum in self._shares1[peerid]
1565                     in_2 = shnum in self._shares2.get(peerid, {})
1566                     in_3 = shnum in self._shares3.get(peerid, {})
1567                     #print peerid_s, shnum, which, in_1, in_2, in_3
1568                     if which == 1:
1569                         if in_1:
1570                             peers[shnum] = self._shares1[peerid][shnum]
1571                             sharemap[shnum] = peerid
1572                     elif which == 2:
1573                         if in_2:
1574                             peers[shnum] = self._shares2[peerid][shnum]
1575                             sharemap[shnum] = peerid
1576                     elif which == 3:
1577                         if in_3:
1578                             peers[shnum] = self._shares3[peerid][shnum]
1579                             sharemap[shnum] = peerid
1580
1581             # we don't bother placing any other shares
1582             # now sort the sequence so that share 0 is returned first
1583             new_sequence = [sharemap[shnum]
1584                             for shnum in sorted(sharemap.keys())]
1585             self._client._storage._sequence = new_sequence
1586             log.msg("merge done")
1587         d.addCallback(_merge)
1588         d.addCallback(lambda res: fn3.download_best_version())
1589         def _retrieved(new_contents):
1590             # the current specified behavior is "first version recoverable"
1591             self.failUnlessEqual(new_contents, contents1)
1592         d.addCallback(_retrieved)
1593         return d
1594
1595
1596 class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
1597
1598     def setUp(self):
1599         return self.publish_multiple()
1600
1601     def test_multiple_versions(self):
1602         # if we see a mix of versions in the grid, download_best_version
1603         # should get the latest one
1604         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
1605         d = self._fn.download_best_version()
1606         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
1607         # and the checker should report problems
1608         d.addCallback(lambda res: self._fn.check(Monitor()))
1609         d.addCallback(self.check_bad, "test_multiple_versions")
1610
1611         # but if everything is at version 2, that's what we should download
1612         d.addCallback(lambda res:
1613                       self._set_versions(dict([(i,2) for i in range(10)])))
1614         d.addCallback(lambda res: self._fn.download_best_version())
1615         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1616         # if exactly one share is at version 3, we should still get v2
1617         d.addCallback(lambda res:
1618                       self._set_versions({0:3}))
1619         d.addCallback(lambda res: self._fn.download_best_version())
1620         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
1621         # but the servermap should see the unrecoverable version. This
1622         # depends upon the single newer share being queried early.
1623         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1624         def _check_smap(smap):
1625             self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
1626             newer = smap.unrecoverable_newer_versions()
1627             self.failUnlessEqual(len(newer), 1)
1628             verinfo, health = newer.items()[0]
1629             self.failUnlessEqual(verinfo[0], 4)
1630             self.failUnlessEqual(health, (1,3))
1631             self.failIf(smap.needs_merge())
1632         d.addCallback(_check_smap)
1633         # if we have a mix of two parallel versions (s4a and s4b), we could
1634         # recover either
1635         d.addCallback(lambda res:
1636                       self._set_versions({0:3,2:3,4:3,6:3,8:3,
1637                                           1:4,3:4,5:4,7:4,9:4}))
1638         d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
1639         def _check_smap_mixed(smap):
1640             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1641             newer = smap.unrecoverable_newer_versions()
1642             self.failUnlessEqual(len(newer), 0)
1643             self.failUnless(smap.needs_merge())
1644         d.addCallback(_check_smap_mixed)
1645         d.addCallback(lambda res: self._fn.download_best_version())
1646         d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
1647                                                   res == self.CONTENTS[4]))
1648         return d
1649
1650     def test_replace(self):
1651         # if we see a mix of versions in the grid, we should be able to
1652         # replace them all with a newer version
1653
1654         # if exactly one share is at the index-3 version (seqnum 4), we
1655         # should download (and replace) the index-2 version; the result gets
1656         # seqnum 5. Note: the _set_versions index is not the sequence number.
1657         target = dict([(i,2) for i in range(10)]) # seqnum3
1658         target[0] = 3 # seqnum4
1659         self._set_versions(target)
1660
1661         def _modify(oldversion):
1662             return oldversion + " modified"
1663         d = self._fn.modify(_modify)
1664         d.addCallback(lambda res: self._fn.download_best_version())
1665         expected = self.CONTENTS[2] + " modified"
1666         d.addCallback(lambda res: self.failUnlessEqual(res, expected))
1667         # and the servermap should indicate that the outlier was replaced too
1668         d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
1669         def _check_smap(smap):
1670             self.failUnlessEqual(smap.highest_seqnum(), 5)
1671             self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
1672             self.failUnlessEqual(len(smap.recoverable_versions()), 1)
1673         d.addCallback(_check_smap)
1674         return d
1675
1676
1677 class Utils(unittest.TestCase):
1678     def test_dict_of_sets(self):
1679         ds = DictOfSets()
1680         ds.add(1, "a")
1681         ds.add(2, "b")
1682         ds.add(2, "b")
1683         ds.add(2, "c")
1684         self.failUnlessEqual(ds[1], set(["a"]))
1685         self.failUnlessEqual(ds[2], set(["b", "c"]))
1686         ds.discard(3, "d") # should not raise an exception
1687         ds.discard(2, "b")
1688         self.failUnlessEqual(ds[2], set(["c"]))
1689         ds.discard(2, "c")
1690         self.failIf(2 in ds)
1691
1692     def _do_inside(self, c, x_start, x_length, y_start, y_length):
1693         # we compare this against sets of integers
1694         x = set(range(x_start, x_start+x_length))
1695         y = set(range(y_start, y_start+y_length))
1696         should_be_inside = x.issubset(y)
1697         self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
1698                                                          y_start, y_length),
1699                              str((x_start, x_length, y_start, y_length)))
1700
1701     def test_cache_inside(self):
1702         c = ResponseCache()
1703         x_start = 10
1704         x_length = 5
1705         for y_start in range(8, 17):
1706             for y_length in range(8):
1707                 self._do_inside(c, x_start, x_length, y_start, y_length)
1708
1709     def _do_overlap(self, c, x_start, x_length, y_start, y_length):
1710         # we compare this against sets of integers
1711         x = set(range(x_start, x_start+x_length))
1712         y = set(range(y_start, y_start+y_length))
1713         overlap = bool(x.intersection(y))
1714         self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
1715                                                       y_start, y_length),
1716                              str((x_start, x_length, y_start, y_length)))
1717
1718     def test_cache_overlap(self):
1719         c = ResponseCache()
1720         x_start = 10
1721         x_length = 5
1722         for y_start in range(8, 17):
1723             for y_length in range(8):
1724                 self._do_overlap(c, x_start, x_length, y_start, y_length)
1725
1726     def test_cache(self):
1727         c = ResponseCache()
1728         # xdata = base62.b2a(os.urandom(100))[:100]
1729         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
1730         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
1731         nope = (None, None)
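         # reads are satisfied only when the requested span lies entirely
         # inside a single cached write; everything else returns (None, None)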
1732         c.add("v1", 1, 0, xdata, "time0")
1733         c.add("v1", 1, 2000, ydata, "time1")
1734         self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
1735         self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
1736         self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
1737         self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
1738         self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
1739         self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
1740         self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
1741         self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
1742         self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
1743         self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
1744         self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
1745         self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
1746         self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
1747         self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
1748         self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
1749         self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
1750         self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
1751         self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
1752
1753         # optional: join fragments
1754         c = ResponseCache()
1755         c.add("v1", 1, 0, xdata[:10], "time0")
1756         c.add("v1", 1, 10, xdata[10:20], "time1")
1757         #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
1758
1759 class Exceptions(unittest.TestCase):
1760     def test_repr(self):
1761         nmde = NeedMoreDataError(100, 50, 100)
1762         self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
1763         ucwe = UncoordinatedWriteError()
1764         self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))
1765
1766 # we can't do this test with a FakeClient, since it uses FakeStorageServer
1767 # instances which always succeed. So we need a less-fake one.
1768
1769 class IntentionalError(Exception):
1770     pass
1771
1772 class LocalWrapper:
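     # wraps a real StorageServer, turning callRemote() into an
     # eventually-fired local method call. Setting .broken=True makes
     # every subsequent call fail with IntentionalError.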
1773     def __init__(self, original):
1774         self.original = original
1775         self.broken = False
1776         self.post_call_notifier = None
1777     def callRemote(self, methname, *args, **kwargs):
1778         def _call():
1779             if self.broken:
1780                 raise IntentionalError("I was asked to break")
1781             meth = getattr(self.original, "remote_" + methname)
1782             return meth(*args, **kwargs)
1783         d = fireEventually()
1784         d.addCallback(lambda res: _call())
1785         if self.post_call_notifier:
1786             d.addCallback(self.post_call_notifier, methname)
1787         return d
1788
1789 class LessFakeClient(FakeClient):
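     # like FakeClient, but each peer is a real StorageServer (writing to
     # local disk) behind a LocalWrapper, so tests can break individual
     # servers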
1790
1791     def __init__(self, basedir, num_peers=10):
1792         self._num_peers = num_peers
1793         self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
1794                          for i in range(self._num_peers)]
1795         self._connections = {}
1796         for peerid in self._peerids:
1797             peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
1798             make_dirs(peerdir)
1799             ss = storage.StorageServer(peerdir)
1800             ss.setNodeID(peerid)
1801             lw = LocalWrapper(ss)
1802             self._connections[peerid] = lw
1803         self.nodeid = "fakenodeid"
1804
1805
1806 class Problems(unittest.TestCase, testutil.ShouldFailMixin):
1807     def test_publish_surprise(self):
1808         basedir = os.path.join("mutable/CollidingWrites/test_surprise")
1809         self.client = LessFakeClient(basedir)
1810         d = self.client.create_mutable_file("contents 1")
1811         def _created(n):
1812             d = defer.succeed(None)
1813             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1814             def _got_smap1(smap):
1815                 # stash the old state of the file
1816                 self.old_map = smap
1817             d.addCallback(_got_smap1)
1818             # then modify the file, leaving the old map untouched
1819             d.addCallback(lambda res: log.msg("starting winning write"))
1820             d.addCallback(lambda res: n.overwrite("contents 2"))
1821             # now attempt to modify the file with the old servermap. This
1822             # will look just like an uncoordinated write, in which every
1823             # single share got updated between our mapupdate and our publish
1824             d.addCallback(lambda res: log.msg("starting doomed write"))
1825             d.addCallback(lambda res:
1826                           self.shouldFail(UncoordinatedWriteError,
1827                                           "test_publish_surprise", None,
1828                                           n.upload,
1829                                           "contents 2a", self.old_map))
1830             return d
1831         d.addCallback(_created)
1832         return d
1833
1834     def test_retrieve_surprise(self):
1835         basedir = os.path.join("mutable/CollidingWrites/test_retrieve")
1836         self.client = LessFakeClient(basedir)
1837         d = self.client.create_mutable_file("contents 1")
1838         def _created(n):
1839             d = defer.succeed(None)
1840             d.addCallback(lambda res: n.get_servermap(MODE_READ))
1841             def _got_smap1(smap):
1842                 # stash the old state of the file
1843                 self.old_map = smap
1844             d.addCallback(_got_smap1)
1845             # then modify the file, leaving the old map untouched
1846             d.addCallback(lambda res: log.msg("starting winning write"))
1847             d.addCallback(lambda res: n.overwrite("contents 2"))
1848             # now attempt to retrieve the old version with the old servermap.
1849             # This will look like someone has changed the file since we
1850             # updated the servermap.
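                 # clearing the response cache forces the retrieve to really
                 # query the servers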
1851             d.addCallback(lambda res: n._cache._clear())
1852             d.addCallback(lambda res: log.msg("starting doomed read"))
1853             d.addCallback(lambda res:
1854                           self.shouldFail(NotEnoughSharesError,
1855                                           "test_retrieve_surprise",
1856                                           "ran out of peers: have 0 shares (k=3)",
1857                                           n.download_version,
1858                                           self.old_map,
1859                                           self.old_map.best_recoverable_version(),
1860                                           ))
1861             return d
1862         d.addCallback(_created)
1863         return d
1864
1865     def test_unexpected_shares(self):
1866         # upload the file, take a servermap, shut down one of the servers,
1867         # upload it again (causing shares to appear on a new server), then
1868         # upload using the old servermap. The last upload should fail with an
1869         # UncoordinatedWriteError, because of the shares that didn't appear
1870         # in the servermap.
1871         basedir = os.path.join("mutable/CollidingWrites/test_unexpected_shares")
1872         self.client = LessFakeClient(basedir)
1873         d = self.client.create_mutable_file("contents 1")
1874         def _created(n):
1875             d = defer.succeed(None)
1876             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
1877             def _got_smap1(smap):
1878                 # stash the old state of the file
1879                 self.old_map = smap
1880                 # now shut down one of the servers
1881                 peer0 = list(smap.make_sharemap()[0])[0]
1882                 self.client._connections.pop(peer0)
1883                 # then modify the file, leaving the old map untouched
1884                 log.msg("starting winning write")
1885                 return n.overwrite("contents 2")
1886             d.addCallback(_got_smap1)
1887             # now attempt to modify the file with the old servermap. This
1888             # will look just like an uncoordinated write, in which every
1889             # single share got updated between our mapupdate and our publish
1890             d.addCallback(lambda res: log.msg("starting doomed write"))
1891             d.addCallback(lambda res:
1892                           self.shouldFail(UncoordinatedWriteError,
1893                                           "test_unexpected_shares", None,
1894                                           n.upload,
1895                                           "contents 2a", self.old_map))
1896             return d
1897         d.addCallback(_created)
1898         return d
1899
1900     def test_bad_server(self):
1901         # Break one server, then create the file: the initial publish should
1902         # complete with an alternate server. Breaking a second server should
1903         # not prevent an update from succeeding either.
1904         basedir = os.path.join("mutable/CollidingWrites/test_bad_server")
1905         self.client = LessFakeClient(basedir, 20)
1906         # to make sure that one of the initial peers is broken, we have to
1907         # get creative. We create the keys, so we can figure out the storage
1908         # index, but we hold off on doing the initial publish until we've
1909         # broken the server on which the first share wants to be stored.
1910         n = FastMutableFileNode(self.client)
1911         d = defer.succeed(None)
1912         d.addCallback(n._generate_pubprivkeys)
1913         d.addCallback(n._generated)
1914         def _break_peer0(res):
1915             si = n.get_storage_index()
1916             peerlist = self.client.get_permuted_peers("storage", si)
1917             peerid0, connection0 = peerlist[0]
1918             peerid1, connection1 = peerlist[1]
1919             connection0.broken = True
1920             self.connection1 = connection1
1921         d.addCallback(_break_peer0)
1922         # now let the initial publish finally happen
1923         d.addCallback(lambda res: n._upload("contents 1", None))
1924         # that ought to work
1925         d.addCallback(lambda res: n.download_best_version())
1926         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
1927         # now break the second peer
1928         def _break_peer1(res):
1929             self.connection1.broken = True
1930         d.addCallback(_break_peer1)
1931         d.addCallback(lambda res: n.overwrite("contents 2"))
1932         # that ought to work too
1933         d.addCallback(lambda res: n.download_best_version())
1934         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
1935         return d
1936
1937     def test_publish_all_servers_bad(self):
1938         # Break all servers: the publish should fail
1939         basedir = os.path.join("mutable/CollidingWrites/publish_all_servers_bad")
1940         self.client = LessFakeClient(basedir, 20)
1941         for connection in self.client._connections.values():
1942             connection.broken = True
1943         d = self.shouldFail(NotEnoughServersError,
1944                             "test_publish_all_servers_bad",
1945                             "Ran out of non-bad servers",
1946                             self.client.create_mutable_file, "contents")
1947         return d
1948
1949     def test_publish_no_servers(self):
1950         # no servers at all: the publish should fail
1951         basedir = os.path.join("mutable/CollidingWrites/publish_no_servers")
1952         self.client = LessFakeClient(basedir, 0)
1953         d = self.shouldFail(NotEnoughServersError,
1954                             "test_publish_no_servers",
1955                             "Ran out of non-bad servers",
1956                             self.client.create_mutable_file, "contents")
1957         return d
1958     test_publish_no_servers.timeout = 30
1959
1960
1961     def test_privkey_query_error(self):
1962         # when a servermap is updated with MODE_WRITE, it tries to get the
1963         # privkey. Something might go wrong during this query attempt.
1964         self.client = FakeClient(20)
1965         # we need some contents that are large enough to push the privkey out
1966         # of the early part of the file
1967         LARGE = "These are Larger contents" * 200 # about 5KB
1968         d = self.client.create_mutable_file(LARGE)
1969         def _created(n):
1970             self.uri = n.get_uri()
1971             self.n2 = self.client.create_node_from_uri(self.uri)
1972             # we start by doing a map update to figure out which is the first
1973             # server.
1974             return n.get_servermap(MODE_WRITE)
1975         d.addCallback(_created)
1976         d.addCallback(lambda res: fireEventually(res))
1977         def _got_smap1(smap):
1978             peer0 = list(smap.make_sharemap()[0])[0]
1979             # we tell the server to respond to this peer first, so that it
1980             # will be asked for the privkey first
1981             self.client._storage._sequence = [peer0]
1982             # now we make the peer fail their second query
1983             self.client._storage._special_answers[peer0] = ["normal", "fail"]
1984         d.addCallback(_got_smap1)
1985         # now we update a servermap from a new node (which doesn't have the
1986         # privkey yet, forcing it to use a separate privkey query). Each
1987         # query response will trigger a privkey query, and since we're using
1988         # _sequence to make the peer0 response come back first, we'll send it
1989         # a privkey query first, and _sequence will again ensure that the
1990         # peer0 query will also come back before the others, and then
1991         # _special_answers will make sure that the query raises an exception.
1992         # The whole point of these hijinks is to exercise the code in
1993         # _privkey_query_failed. Note that the map-update will succeed, since
1994         # we'll just get a copy from one of the other shares.
1995         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
1996         # Using FakeStorage._sequence means there will be read requests still
1997         # floating around; wait for them to retire
1998         def _cancel_timer(res):
1999             if self.client._storage._pending_timer:
2000                 self.client._storage._pending_timer.cancel()
2001             return res
2002         d.addBoth(_cancel_timer)
2003         return d
2004
2005     def test_privkey_query_missing(self):
2006         # like test_privkey_query_error, but the shares are deleted by the
2007         # second query, instead of raising an exception.
2008         self.client = FakeClient(20)
2009         LARGE = "These are Larger contents" * 200 # about 5KB
2010         d = self.client.create_mutable_file(LARGE)
2011         def _created(n):
2012             self.uri = n.get_uri()
2013             self.n2 = self.client.create_node_from_uri(self.uri)
2014             return n.get_servermap(MODE_WRITE)
2015         d.addCallback(_created)
2016         d.addCallback(lambda res: fireEventually(res))
2017         def _got_smap1(smap):
2018             peer0 = list(smap.make_sharemap()[0])[0]
2019             self.client._storage._sequence = [peer0]
2020             self.client._storage._special_answers[peer0] = ["normal", "none"]
2021         d.addCallback(_got_smap1)
2022         d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
2023         def _cancel_timer(res):
2024             if self.client._storage._pending_timer:
2025                 self.client._storage._pending_timer.cancel()
2026             return res
2027         d.addBoth(_cancel_timer)
2028         return d