import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, download, storage
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.util.fileutil import make_dirs
from allmydata.encode import NotEnoughSharesError
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
     NotEnoughServersError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

class IntentionalError(Exception):
    # raised by FakeStorage when a test asks a peer to fail deliberately
    pass

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}
        self._pending_timer = None
        self._special_answers = {}
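    # A minimal sketch of how a test can drive these knobs (the peerid
    # values here are hypothetical placeholders): setting
    #   s._sequence = [peerid_b, peerid_a]
    # holds every read() in _pending and releases the answers in that order
    # one second after the first query, while
    #   s._special_answers[peerid_a] = ["fail", "normal"]
    # makes peerid_a's next read() fire with a Failure(IntentionalError())
    # and the one after that answer normally.
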
    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                shares = failure.Failure(IntentionalError())
            elif mode == "none":
                shares = {}
            elif mode == "normal":
                pass
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        self._pending_timer = None
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)
    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
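        # overlay the new data on top of any existing share contents,
        # preserving whatever lies beyond the written range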
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)

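# A minimal sketch of how the fakes above are exercised (inside a trial
# test, so a reactor is running; the peerid and "si" values are
# hypothetical placeholders):
#
#   ss = FakeStorageServer(peerid, FakeStorage())
#   d = ss.callRemote("slot_readv", "si", [], [(0, 100)])
#   # fires with a dict mapping each shnum to its requested byte ranges,
#   # just as a Foolscap RemoteReference to a real storage server would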

# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p, size):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [(r[1], r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])
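# e.g. flip_bit("ABC", 1) XORs the low bit of "B" (0x42 -> 0x43), yielding
# "ACC": a single-bit corruption at the given byte offset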

def corrupt(res, s, offset, shnums_to_corrupt=None):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
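
# Sketch of the corrupt() calling convention, as used by the tests below:
# the offset may be an integer (absolute byte position), a field name from
# the share's offset table (e.g. "signature"), the special string "pubkey",
# or a (name, delta) tuple. For example:
#
#   corrupt(None, storage, "signature")            # every share's signature
#   corrupt(None, storage, ("share_hash_chain", 4), shnums_to_corrupt=[0, 1])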

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,))
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            # test a file that is large enough to overcome the
            # mapupdate-to-retrieve data caching (i.e. make the shares larger
            # than the default readsize, which is 2000 bytes). A 15kB file
            # will have 5kB shares.
            d.addCallback(lambda res: n.overwrite("large size file" * 1000))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res:
                          self.failUnlessEqual(res, "large size file" * 1000))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

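        # (n.modify(f) fetches the current contents, runs f on them, and
        #  writes the result back; judging by the assertions below, returning
        #  None or the unchanged string skips the write and leaves the seqnum
        #  alone, while an exception from f aborts the modify entirely.)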
        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
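                # (len(CONTENTS) is 21 bytes; with k=3 required shares, each
                #  FEC share carries 21/3 == 7 bytes)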
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = (["%07d" % i for i in range(10)], range(10))
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
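                # (the signature covers the packed prefix fields: version
                #  byte, seqnum, root hash, salt/IV, k, N, segsize, datalen)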
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10)) == 4
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share on each of
    # 10 of them. When we publish to 3 peers, we should get either 3 or 4
    # shares per peer. When we publish to zero peers, we should get a
    # NotEnoughSharesError.

class Servermap(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file("New contents go here")
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum)
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2, 3))
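        # (2 shares seen, but k=3 are needed for recovery)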
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d



class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d


    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_best_version()
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_best_version)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

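    # The small integer offsets used below index into the share's signed
    # header; judging by the field each test names, the layout is: version
    # byte at 0, seqnum at 1, root hash at 9, IV/salt at 41, k at 57, N at
    # 58, segsize at 59, and datalen at 67. String offsets are looked up in
    # the share's offset table by corrupt() above.
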
    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain", 4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class MultipleEncodings(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in several different ways (3-of-10,
        # 4-of-9, and 4-of-7), then mix up the shares, to make sure that
        # download survives seeing a variety of encodings. This is actually
        # kind of tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we upload a file through fn1, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2, or
            # set3 according to the following sequence:
            #
            #  4-of-9   a  s2
            #  4-of-9   b  s2
            #  4-of-7   c   s3
            #  4-of-9   d  s2
            #  3-of-10  e s1
            #  3-of-10  f s1
            #  3-of-10  g s1
            #  4-of-9   h  s2
            #
            # so that neither form can be recovered until fetch [f], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will be recoverable at fetch [g].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]
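            # (places[shnum] names which share-set to place for shnums 0..7,
            #  matching rows a..h in the table above; higher shnums get none)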
1114
1115             sharemap = {}
1116
1117             for i,peerid in enumerate(self._client._peerids):
1118                 peerid_s = shortnodeid_b2a(peerid)
1119                 for shnum in self._shares1.get(peerid, {}):
1120                     if shnum < len(places):
1121                         which = places[shnum]
1122                     else:
1123                         which = "x"
1124                     self._client._storage._peers[peerid] = peers = {}
1125                     in_1 = shnum in self._shares1[peerid]
1126                     in_2 = shnum in self._shares2.get(peerid, {})
1127                     in_3 = shnum in self._shares3.get(peerid, {})
1128                     #print peerid_s, shnum, which, in_1, in_2, in_3
1129                     if which == 1:
1130                         if in_1:
1131                             peers[shnum] = self._shares1[peerid][shnum]
1132                             sharemap[shnum] = peerid
1133                     elif which == 2:
1134                         if in_2:
1135                             peers[shnum] = self._shares2[peerid][shnum]
1136                             sharemap[shnum] = peerid
1137                     elif which == 3:
1138                         if in_3:
1139                             peers[shnum] = self._shares3[peerid][shnum]
1140                             sharemap[shnum] = peerid
1141
1142             # we don't bother placing any other shares
1143             # now sort the sequence so that share 0 is returned first
1144             new_sequence = [sharemap[shnum]
1145                             for shnum in sorted(sharemap.keys())]
1146             self._client._storage._sequence = new_sequence
1147             log.msg("merge done")
1148         d.addCallback(_merge)
1149         d.addCallback(lambda res: fn3.download_best_version())
1150         def _retrieved(new_contents):
1151             # the current specified behavior is "first version recoverable"
1152             self.failUnlessEqual(new_contents, contents1)
1153         d.addCallback(_retrieved)
1154         return d
1155
1156 class MultipleVersions(unittest.TestCase):
1157     def setUp(self):
1158         self.CONTENTS = ["Contents 0",
1159                          "Contents 1",
1160                          "Contents 2",
1161                          "Contents 3a",
1162                          "Contents 3b"]
1163         self._copied_shares = {}
1164         num_peers = 20
1165         self._client = FakeClient(num_peers)
1166         self._storage = self._client._storage
1167         d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
1168         def _created(node):
1169             self._fn = node
1170             # now create multiple versions of the same file, and accumulate
1171             # their shares, so we can mix and match them later.
1172             d = defer.succeed(None)
1173             d.addCallback(self._copy_shares, 0)
1174             d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
1175             d.addCallback(self._copy_shares, 1)
1176             d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
1177             d.addCallback(self._copy_shares, 2)
1178             d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
1179             d.addCallback(self._copy_shares, 3)
1180             # now we replace all the shares with version s3, and upload a new
1181             # version to get s4b.
1182             rollback = dict([(i,2) for i in range(10)])
1183             d.addCallback(lambda res: self._set_versions(rollback))
1184             d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
1185             d.addCallback(self._copy_shares, 4)
1186             # we leave the storage in state 4
1187             return d
1188         d.addCallback(_created)
1189         return d
1190
1191     def _copy_shares(self, ignored, index):
1192         shares = self._client._storage._peers
1193         # we need a deep copy
1194         new_shares = {}
1195         for peerid in shares:
1196             new_shares[peerid] = {}
1197             for shnum in shares[peerid]:
1198                 new_shares[peerid][shnum] = shares[peerid][shnum]
1199         self._copied_shares[index] = new_shares
1200
1201     def _set_versions(self, versionmap):
1202         # versionmap maps shnums to which version (0,1,2,3,4) we want the
1203         # share to be at. Any shnum which is left out of the map will stay at
1204         # its current version.
1205         shares = self._client._storage._peers
1206         oldshares = self._copied_shares
1207         for peerid in shares:
1208             for shnum in shares[peerid]:
1209                 if shnum in versionmap:
1210                     index = versionmap[shnum]
1211                     shares[peerid][shnum] = oldshares[index][peerid][shnum]

    def test_multiple_versions(self):
        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
        d = self._fn.download_best_version()
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
        # but if everything is at version 2, that's what we should download
        d.addCallback(lambda res:
                      self._set_versions(dict([(i,2) for i in range(10)])))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # if exactly one share is at version 3, we should still get v2
        d.addCallback(lambda res:
                      self._set_versions({0:3}))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = newer.items()[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
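            # (our gloss on the (1,3) health tuple: one share seen of a
            # version that needs k=3 shares to be recoverable)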
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
                                                  res == self.CONTENTS[4]))
        return d

    def test_replace(self):
        # if we see a mix of versions in the grid, we should be able to
        # replace them all with a newer version

        # if exactly one share is at version 3, we should download (and
        # replace) v2, and the result should be v4. Note that the index we
        # give to _set_versions is different from the sequence number.
        target = dict([(i,2) for i in range(10)]) # seqnum3
        target[0] = 3 # seqnum4
        self._set_versions(target)

        def _modify(oldversion):
            return oldversion + " modified"
        d = self._fn.modify(_modify)
        d.addCallback(lambda res: self._fn.download_best_version())
        expected = self.CONTENTS[2] + " modified"
        d.addCallback(lambda res: self.failUnlessEqual(res, expected))
        # and the servermap should indicate that the outlier was replaced too
        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
        def _check_smap(smap):
            self.failUnlessEqual(smap.highest_seqnum(), 5)
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
        d.addCallback(_check_smap)
        return d


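# A standalone sketch (ours, not part of the code under test) of the
# two-level copy that MultipleVersions._copy_shares spells out by hand.
# copy.deepcopy(shares) would do the same job, at the cost of also copying
# the share strings, which are immutable and never modified in place.
def _example_snapshot_shares(peers):
    # peers maps peerid -> {shnum: share_data}; a fresh inner dict per
    # peerid provides all the isolation the tests need
    return dict((peerid, dict(shares))
                for (peerid, shares) in peers.items())
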
class Utils(unittest.TestCase):
    def test_dict_of_sets(self):
        ds = DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        self.failIf(2 in ds)

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        should_be_inside = x.issubset(y)
        self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
                                                         y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_inside(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(c, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        overlap = bool(x.intersection(y))
        self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
                                                      y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_overlap(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(c, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
        self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
        self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
        self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
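        # in short: a read() hits only when the requested [offset,
        # offset+length) range lies entirely within a single previously
        # add()ed span; anything poking past a span's edge, or into the
        # 100..2000 gap, returns (None, None)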

        # optional: join fragments
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))

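# A standalone sketch (ours, not ResponseCache's actual code) of the
# arithmetic that the set-based _do_inside/_do_overlap checks above are
# exercising, using half-open [start, start+length) ranges:
def _example_inside(x_start, x_length, y_start, y_length):
    # a nonempty x lies inside y iff y starts at or before x and ends at
    # or after x does (the tests above only use a nonempty x)
    return (y_start <= x_start and
            x_start + x_length <= y_start + y_length)

def _example_overlap(x_start, x_length, y_start, y_length):
    # two ranges overlap iff the later start comes strictly before the
    # earlier end; this form is also correct when either range is empty
    return (max(x_start, y_start) <
            min(x_start + x_length, y_start + y_length))
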
class Exceptions(unittest.TestCase):
    def test_repr(self):
        nmde = NeedMoreDataError(100, 50, 100)
        self.failUnless("NeedMoreDataError" in repr(nmde), repr(nmde))
        ucwe = UncoordinatedWriteError()
        self.failUnless("UncoordinatedWriteError" in repr(ucwe), repr(ucwe))

# we can't do this test with a FakeClient, since it uses FakeStorageServer
# instances which always succeed. So we need a less-fake one.

class IntentionalError(Exception):
    pass

class LocalWrapper:
    def __init__(self, original):
        self.original = original
        self.broken = False
        self.post_call_notifier = None
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, methname)
        return d

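# Roughly how LocalWrapper gets used below (a sketch; "si" stands in for
# a real storage index):
#
#   ss = storage.StorageServer(peerdir)
#   rref = LocalWrapper(ss)
#   d = rref.callRemote("get_buckets", si) # a Deferred, like a real
#                                          # foolscap RemoteReference
#   rref.broken = True # every later call now errbacks with
#                      # IntentionalError, after the eventual-send turn
#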
class LessFakeClient(FakeClient):

    def __init__(self, basedir, num_peers=10):
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = {}
        for peerid in self._peerids:
            peerdir = os.path.join(basedir, idlib.shortnodeid_b2a(peerid))
            make_dirs(peerdir)
            ss = storage.StorageServer(peerdir)
            ss.setNodeID(peerid)
            lw = LocalWrapper(ss)
            self._connections[peerid] = lw
        self.nodeid = "fakenodeid"


class Problems(unittest.TestCase, testutil.ShouldFailMixin):
    def test_publish_surprise(self):
        basedir = os.path.join("mutable", "CollidingWrites", "test_surprise")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
            d.addCallback(_got_smap1)
            # then modify the file, leaving the old map untouched
            d.addCallback(lambda res: log.msg("starting winning write"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # now attempt to modify the file with the old servermap. This
            # will look just like an uncoordinated write, in which every
            # single share got updated between our mapupdate and our publish
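            # (our reading: the publish writes with test-vectors derived
            # from the stale map, every server reports a newer seqnum than
            # the map predicted, and the writer concludes it has collided)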
            d.addCallback(lambda res: log.msg("starting doomed write"))
            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "test_publish_surprise", None,
                                          n.upload,
                                          "contents 2a", self.old_map))
            return d
        d.addCallback(_created)
        return d

    def test_retrieve_surprise(self):
        basedir = os.path.join("mutable", "CollidingWrites", "test_retrieve")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
            d.addCallback(_got_smap1)
            # then modify the file, leaving the old map untouched
            d.addCallback(lambda res: log.msg("starting winning write"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            # now attempt to retrieve the old version with the old servermap.
            # This will look like someone has changed the file since we
            # updated the servermap.
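            # (we clear the node's read cache first so the retrieve really
            # hits the servers instead of being satisfied from cached share
            # data -- our reading of why _clear() is needed here)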
            d.addCallback(lambda res: n._cache._clear())
            d.addCallback(lambda res: log.msg("starting doomed read"))
            d.addCallback(lambda res:
                          self.shouldFail(NotEnoughSharesError,
                                          "test_retrieve_surprise",
                                          "ran out of peers: have 0 shares (k=3)",
                                          n.download_version,
                                          self.old_map,
                                          self.old_map.best_recoverable_version(),
                                          ))
            return d
        d.addCallback(_created)
        return d

    def test_unexpected_shares(self):
        # upload the file, take a servermap, shut down one of the servers,
        # upload it again (causing shares to appear on a new server), then
        # upload using the old servermap. The last upload should fail with an
        # UncoordinatedWriteError, because of the shares that didn't appear
        # in the servermap.
        basedir = os.path.join("mutable", "CollidingWrites",
                               "test_unexpected_shares")
        self.client = LessFakeClient(basedir)
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            def _got_smap1(smap):
                # stash the old state of the file
                self.old_map = smap
                # now shut down one of the servers
                peer0 = list(smap.make_sharemap()[0])[0]
                self.client._connections.pop(peer0)
                # then modify the file, leaving the old map untouched
                log.msg("starting winning write")
                return n.overwrite("contents 2")
            d.addCallback(_got_smap1)
            # now attempt to modify the file with the old servermap. This
            # will look just like an uncoordinated write, in which every
            # single share got updated between our mapupdate and our publish
            d.addCallback(lambda res: log.msg("starting doomed write"))
            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "test_unexpected_shares", None,
                                          n.upload,
                                          "contents 2a", self.old_map))
            return d
        d.addCallback(_created)
        return d

    def test_bad_server(self):
        # Break one server, then create the file: the initial publish should
        # complete with an alternate server. Breaking a second server should
        # not prevent an update from succeeding either.
        basedir = os.path.join("mutable", "CollidingWrites", "test_bad_server")
        self.client = LessFakeClient(basedir, 20)
        # to make sure that one of the initial peers is broken, we have to
        # get creative. We create the keys, so we can figure out the storage
        # index, but we hold off on doing the initial publish until we've
        # broken the server on which the first share wants to be stored.
        n = FastMutableFileNode(self.client)
        d = defer.succeed(None)
        d.addCallback(n._generate_pubprivkeys)
        d.addCallback(n._generated)
        def _break_peer0(res):
            si = n.get_storage_index()
            peerlist = self.client.get_permuted_peers("storage", si)
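            # (get_permuted_peers orders peers by a permuted hash of the
            # storage index and peerid, which is the order shares are
            # placed in, so peerlist[0] is where share 0 wants to land)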
            peerid0, connection0 = peerlist[0]
            peerid1, connection1 = peerlist[1]
            connection0.broken = True
            self.connection1 = connection1
        d.addCallback(_break_peer0)
        # now let the initial publish finally happen
        d.addCallback(lambda res: n._upload("contents 1", None))
        # that ought to work
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
        # now break the second peer
        def _break_peer1(res):
            self.connection1.broken = True
        d.addCallback(_break_peer1)
        d.addCallback(lambda res: n.overwrite("contents 2"))
        # that ought to work too
        d.addCallback(lambda res: n.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
        return d

    def test_publish_all_servers_bad(self):
        # Break all servers: the publish should fail
        basedir = os.path.join("mutable", "CollidingWrites",
                               "publish_all_servers_bad")
        self.client = LessFakeClient(basedir, 20)
        for connection in self.client._connections.values():
            connection.broken = True
        d = self.shouldFail(NotEnoughServersError,
                            "test_publish_all_servers_bad",
                            "Ran out of non-bad servers",
                            self.client.create_mutable_file, "contents")
        return d

    def test_privkey_query_error(self):
        # when a servermap is updated with MODE_WRITE, it tries to get the
        # privkey. Something might go wrong during this query attempt.
        self.client = FakeClient(20)
        # we need some contents that are large enough to push the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d = self.client.create_mutable_file(LARGE)
        def _created(n):
            self.uri = n.get_uri()
            self.n2 = self.client.create_node_from_uri(self.uri)
            # we start by doing a map update to figure out which is the first
            # server.
            return n.get_servermap(MODE_WRITE)
        d.addCallback(_created)
        d.addCallback(lambda res: fireEventually(res))
        def _got_smap1(smap):
            peer0 = list(smap.make_sharemap()[0])[0]
            # we arrange for peer0's response to come back first, so that
            # peer0 will be asked for the privkey first
            self.client._storage._sequence = [peer0]
            # now we make that peer fail its second query
            self.client._storage._special_answers[peer0] = ["normal", "fail"]
        d.addCallback(_got_smap1)
        # now we update a servermap from a new node (which doesn't have the
        # privkey yet, forcing it to use a separate privkey query). Each
        # query response will trigger a privkey query, and since we're using
        # _sequence to make the peer0 response come back first, we'll send it
        # a privkey query first, and _sequence will again ensure that the
        # peer0 query will also come back before the others, and then
        # _special_answers will make sure that the query raises an exception.
        # The whole point of these hijinks is to exercise the code in
        # _privkey_query_failed. Note that the map-update will succeed, since
        # we'll just get a copy from one of the other shares.
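        # our paraphrase of the intended sequence:
        #   1. n2's mapupdate queries all peers; peer0's answer arrives first
        #   2. that answer triggers a separate privkey query to peer0
        #   3. _special_answers makes this second query fail
        #   4. _privkey_query_failed runs, the mapupdate carries on, and the
        #      privkey is fetched from one of the other shares instead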
        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
        # Using FakeStorage._sequence means there will be read requests still
        # floating around... wait for them to retire
        def _cancel_timer(res):
            if self.client._storage._pending_timer:
                self.client._storage._pending_timer.cancel()
            return res
        d.addBoth(_cancel_timer)
        return d

    def test_privkey_query_missing(self):
        # like test_privkey_query_error, but the shares are deleted by the
        # second query, instead of raising an exception.
        self.client = FakeClient(20)
        LARGE = "These are Larger contents" * 200 # about 5KB
        d = self.client.create_mutable_file(LARGE)
        def _created(n):
            self.uri = n.get_uri()
            self.n2 = self.client.create_node_from_uri(self.uri)
            return n.get_servermap(MODE_WRITE)
        d.addCallback(_created)
        d.addCallback(lambda res: fireEventually(res))
        def _got_smap1(smap):
            peer0 = list(smap.make_sharemap()[0])[0]
            self.client._storage._sequence = [peer0]
            self.client._storage._special_answers[peer0] = ["normal", "none"]
        d.addCallback(_got_smap1)
        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
        def _cancel_timer(res):
            if self.client._storage._pending_timer:
                self.client._storage._pending_timer.cancel()
            return res
        d.addBoth(_cancel_timer)
        return d