import struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from allmydata import uri, download
from allmydata.util import base32, testutil
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.encode import NotEnoughSharesError
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode, BackoffAgent
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
     UnrecoverableFileError, UncoordinatedWriteError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522
# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.

    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order). See the usage sketch after this class.
        self._sequence = None
        self._pending = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
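
# A minimal usage sketch (illustration only, not exercised by the suite): a
# test imposes a response order by assigning a list of peerids to _sequence
# before triggering reads. The queued answers are then released in that
# order, about one second after the first read arrives.
# test_multiple_encodings (below) uses this hook to make share 0 come back
# first.
def _example_sequence_usage():
    s = FakeStorage()
    peer_a = "a" * 20
    peer_b = "b" * 20
    s.write(peer_a, "ignored-si", 0, 0, "share data A")
    s.write(peer_b, "ignored-si", 1, 0, "share data B")
    s._sequence = [peer_b, peer_a] # answer B's query first, then A's
    d_a = s.read(peer_a, "ignored-si")
    d_b = s.read(peer_b, "ignored-si")
    return d_a, d_b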

class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)
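
# A hedged sketch (not used by the tests) of driving the fake wire protocol
# directly: read the first 100 bytes of every share one peer holds. An empty
# shnums list means "all shares", matching the slot_readv logic above.
def _example_slot_readv():
    storage = FakeStorage()
    server = FakeStorageServer("x" * 20, storage)
    storage.write(server.peerid, "ignored-si", 0, 0, "some share bytes")
    d = server.callRemote("slot_readv", "ignored-si", [], [(0, 100)])
    # fires with {0: ["some share bytes"]}
    return d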

# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d
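
# Hedged illustration (not part of the suite) of the permuted-peer ordering
# implemented above: for a fixed key the ordering is deterministic, and it
# covers every peer exactly once.
def _example_permuted_peers():
    c = FakeClient(num_peers=5)
    list1 = c.get_permuted_peers("storage", "some-key")
    list2 = c.get_permuted_peers("storage", "some-key")
    assert list1 == list2 # stable for a given key
    peerids = [peerid for (peerid, conn) in list1]
    assert sorted(peerids) == sorted(c._peerids) # a permutation, not a subset
    return list1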

def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

def corrupt(res, s, offset, shnums_to_corrupt=None):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
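
# Hedged usage sketch (mirrors how the Roundtrip tests below call corrupt):
# the offset may be a named header field, a (field, byte-delta) tuple, or a
# raw integer, and corrupt() returns its first argument so it can sit in a
# Deferred callback chain. Here we flip one bit, four bytes into the
# share_hash_chain, in shares 0-4 only.
def _example_corrupt_usage(storage):
    return corrupt(None, storage, ("share_hash_chain", 4),
                   shnums_to_corrupt=range(5))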

class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_serialize(self):
        n = MutableFileNode(self.client)
        calls = []
        def _callback(*args, **kwargs):
            self.failUnlessEqual(args, (4,) )
            self.failUnlessEqual(kwargs, {"foo": 5})
            calls.append(1)
            return 6
        d = n._do_serialized(_callback, 4, foo=5)
        def _check_callback(res):
            self.failUnlessEqual(res, 6)
            self.failUnlessEqual(calls, [1])
        d.addCallback(_check_callback)

        def _errback():
            raise ValueError("heya")
        d.addCallback(lambda res:
                      self.shouldFail(ValueError, "_check_errback", "heya",
                                      n._do_serialized, _errback))
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
        d = n.get_servermap(MODE_READ)
        d.addCallback(lambda servermap: servermap.best_recoverable_version())
        d.addCallback(lambda verinfo:
                      self.failUnlessEqual(verinfo[0], expected_seqnum))
        return d

    def test_modify(self):
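        # the modifier contract, exercised below: a modifier takes the old
        # contents and returns the new contents, or returns None (or the
        # unchanged contents) to skip the publish and leave the sequence
        # number alone. (See also the sketch after this class.)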
        def _modifier(old_contents):
            return old_contents + "line2"
        def _non_modifier(old_contents):
            return old_contents
        def _none_modifier(old_contents):
            return None
        def _error_modifier(old_contents):
            raise ValueError("oops")
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_non_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_none_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(ValueError, "error_modifier", None,
                                          n.modify, _error_modifier))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_modify_backoffer(self):
        def _modifier(old_contents):
            return old_contents + "line2"
        calls = []
        def _ucw_error_modifier(old_contents):
            # simulate an UncoordinatedWriteError once
            calls.append(1)
            if len(calls) <= 1:
                raise UncoordinatedWriteError("simulated")
            return old_contents + "line3"
        def _always_ucw_error_modifier(old_contents):
            raise UncoordinatedWriteError("simulated")
        def _backoff_stopper(node, f):
            return f
        def _backoff_pauser(node, f):
            d = defer.Deferred()
            reactor.callLater(0.5, d.callback, None)
            return d

        # the give-up-er will hit its maximum retry count quickly
        giveuper = BackoffAgent()
        giveuper._delay = 0.1
        giveuper.factor = 1

        d = self.client.create_mutable_file("line1")
        def _created(n):
            d = n.modify(_modifier)
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "_backoff_stopper", None,
                                          n.modify, _ucw_error_modifier,
                                          _backoff_stopper))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))

            def _reset_ucw_error_modifier(res):
                calls[:] = []
                return res
            d.addCallback(_reset_ucw_error_modifier)
            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
                                               _backoff_pauser))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            d.addCallback(lambda res:
                          self.shouldFail(UncoordinatedWriteError,
                                          "giveuper", None,
                                          n.modify, _always_ucw_error_modifier,
                                          giveuper.delay))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res,
                                                           "line1line2line3"))
            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))

            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d
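
# A hedged sketch of the modifier contract exercised by test_modify above:
# a modifier is a plain function from old contents to new contents, where
# returning None (or the unchanged contents) leaves both the file and its
# sequence number alone.
def _example_appending_modifier(old_contents):
    if old_contents.endswith("line2"):
        return None # no change: modify() will skip the publish
    return old_contents + "line2"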

class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ceil(log2(10))
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer on
    # 10 of them; when we publish to 3 peers, we should get either 3 or 4
    # shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError
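
# A hedged sketch of the zero-peers case from the TODO above; not wired into
# the suite, and the exact point where the error surfaces (create vs.
# publish) is an assumption rather than something this file establishes.
def _sketch_publish_to_zero_peers():
    c = FakeClient(num_peers=0)
    d = c.create_mutable_file("contents")
    def _check(f):
        f.trap(NotEnoughSharesError) # assumed failure mode, per the TODO
    d.addErrback(_check)
    return d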

class Servermap(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file("New contents go here")
        def _created(node):
            self._fn = node
            self._fn2 = self._client.create_node_from_uri(node.get_uri())
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_CHECK, fn=None):
        if fn is None:
            fn = self._fn
        smu = ServermapUpdater(fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3))
        shnum, peerids = sm.make_sharemap().items()[0]
        peerid = list(peerids)[0]
        self.failUnlessEqual(sm.version_on_peer(peerid, shnum), best)
        self.failUnlessEqual(sm.version_on_peer(peerid, 666), None)
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_fetch_privkey(self):
        d = defer.succeed(None)
        # use the sibling filenode (which hasn't been used yet), and make
        # sure it can fetch the privkey. The file is small, so the privkey
        # will be fetched on the first (query) pass.
        d.addCallback(lambda res: self.make_servermap(MODE_WRITE, self._fn2))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        # create a new file, which is large enough to knock the privkey out
        # of the early part of the file
        LARGE = "These are Larger contents" * 200 # about 5KB
        d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
        def _created(large_fn):
            large_fn2 = self._client.create_node_from_uri(large_fn.get_uri())
            return self.make_servermap(MODE_WRITE, large_fn2)
        d.addCallback(_created)
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones. (See
            # the sketch after this class.)
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum)
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3) )
        return sm

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left, on two peers
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda sm:
                      self.failUnlessEqual(len(sm.make_sharemap()), 2))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d
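
# A hedged sketch of the mark-bad-then-update pattern test_mark_bad uses:
# marking a share bad drops it from the map, and a MODE_WRITE update then
# goes looking for replacements.
def _sketch_mark_bad_and_update(fn, sm, peerid, shnum):
    sm.mark_bad_share(peerid, shnum)
    smu = ServermapUpdater(fn, sm, MODE_WRITE)
    return smu.update() # fires with the updated ServerMap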

class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return servermap
            if should_succeed:
                d1 = self._fn.download_best_version()
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
            else:
                d1 = self.shouldFail(NotEnoughSharesError,
                                     "_corrupt_all(offset=%s)" % (offset,),
                                     substring,
                                     self._fn.download_best_version)
            d1.addCallback(lambda res: servermap)
            return d1
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        d = self._test_corrupt_all(0, "AssertionError")
        def _check_servermap(servermap):
            # and the dump should mention the problems
            s = StringIO()
            dump = servermap.dump(s).getvalue()
            self.failUnless("10 PROBLEMS" in dump, dump)
        d.addCallback(_check_servermap)
        return d

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d
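
# A hedged sketch of the corrupt-then-map recipe that _test_corrupt_all
# (above) wraps: flip one bit in a named field of every share, rebuild the
# servermap, and ask what is still recoverable (None means the corruption
# killed every share).
def _sketch_corrupt_one_field(storage, fn, field):
    corrupt(None, storage, field)
    smu = ServermapUpdater(fn, ServerMap(), MODE_READ)
    d = smu.update()
    d.addCallback(lambda sm: sm.best_recoverable_version())
    return d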

class MultipleEncodings(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we upload a file through fn1, and grab its shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2, or
            # set3 according to the following sequence:
            #
            #  4-of-9   a  s2
            #  4-of-9   b  s2
            #  4-of-7   c   s3
            #  4-of-9   d  s2
            #  3-of-10  e s1
            #  3-of-10  f s1
            #  3-of-10  g s1
            #  4-of-9   h  s2
            #
            # so that neither form can be recovered until fetch [g], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees
            # (s2, from share a), it must keep fetching until s2 becomes
            # recoverable at fetch [h].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]

            sharemap = {}

            for i,peerid in enumerate(self._client._peerids):
                peerid_s = shortnodeid_b2a(peerid)
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    self._client._storage._peers[peerid] = peers = {}
                    in_1 = shnum in self._shares1[peerid]
                    in_2 = shnum in self._shares2.get(peerid, {})
                    in_3 = shnum in self._shares3.get(peerid, {})
                    #print peerid_s, shnum, which, in_1, in_2, in_3
                    if which == 1:
                        if in_1:
                            peers[shnum] = self._shares1[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 2:
                        if in_2:
                            peers[shnum] = self._shares2[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 3:
                        if in_3:
                            peers[shnum] = self._shares3[peerid][shnum]
                            sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._client._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d
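
# A hedged sketch of the clone-and-re-encode recipe _encode (above) uses:
# everything comes from the existing filenode except k and N. The caller
# would then invoke .publish(data) on the returned Publish object.
def _sketch_reencode(client, fn, k, n):
    fn2 = FastMutableFileNode(client)
    fn2.init_from_uri(fn.get_uri()) # URI supplies keys and storage index
    fn2._pubkey = fn._pubkey # normally fetched from existing shares
    fn2._privkey = fn._privkey
    fn2._encprivkey = fn._encprivkey
    fn2._required_shares = k # the only deliberate differences
    fn2._total_shares = n
    return Publish(fn2, None)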

class MultipleVersions(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = ["Contents 0",
                         "Contents 1",
                         "Contents 2",
                         "Contents 3a",
                         "Contents 3b"]
        self._copied_shares = {}
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
        def _created(node):
            self._fn = node
            # now create multiple versions of the same file, and accumulate
            # their shares, so we can mix and match them later.
            d = defer.succeed(None)
            d.addCallback(self._copy_shares, 0)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
            d.addCallback(self._copy_shares, 1)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
            d.addCallback(self._copy_shares, 2)
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
            d.addCallback(self._copy_shares, 3)
            # now we replace all the shares with version s3, and upload a new
            # version to get s4b.
            rollback = dict([(i,2) for i in range(10)])
            d.addCallback(lambda res: self._set_versions(rollback))
            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
            d.addCallback(self._copy_shares, 4)
            # we leave the storage in state 4
            return d
        d.addCallback(_created)
        return d

    def _copy_shares(self, ignored, index):
        shares = self._client._storage._peers
        # we need a deep copy
        new_shares = {}
        for peerid in shares:
            new_shares[peerid] = {}
            for shnum in shares[peerid]:
                new_shares[peerid][shnum] = shares[peerid][shnum]
        self._copied_shares[index] = new_shares

    def _set_versions(self, versionmap):
        # versionmap maps shnums to which version (0,1,2,3,4) we want the
        # share to be at. Any shnum which is left out of the map will stay at
        # its current version.
        shares = self._client._storage._peers
        oldshares = self._copied_shares
        for peerid in shares:
            for shnum in shares[peerid]:
                if shnum in versionmap:
                    index = versionmap[shnum]
                    shares[peerid][shnum] = oldshares[index][peerid][shnum]

    def test_multiple_versions(self):
        # if we see a mix of versions in the grid, download_best_version
        # should get the latest one
        self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
        d = self._fn.download_best_version()
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
        # but if everything is at version 2, that's what we should download
        d.addCallback(lambda res:
                      self._set_versions(dict([(i,2) for i in range(10)])))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # if exactly one share is at version 3, we should still get v2
        d.addCallback(lambda res:
                      self._set_versions({0:3}))
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[2]))
        # but the servermap should see the unrecoverable version. This
        # depends upon the single newer share being queried early.
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 1)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 1)
            verinfo, health = newer.items()[0]
            self.failUnlessEqual(verinfo[0], 4)
            self.failUnlessEqual(health, (1,3))
            self.failIf(smap.needs_merge())
        d.addCallback(_check_smap)
        # if we have a mix of two parallel versions (s4a and s4b), we could
        # recover either
        d.addCallback(lambda res:
                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
                                          1:4,3:4,5:4,7:4,9:4}))
        d.addCallback(lambda res: self._fn.get_servermap(MODE_READ))
        def _check_smap_mixed(smap):
            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
            newer = smap.unrecoverable_newer_versions()
            self.failUnlessEqual(len(newer), 0)
            self.failUnless(smap.needs_merge())
        d.addCallback(_check_smap_mixed)
        d.addCallback(lambda res: self._fn.download_best_version())
        d.addCallback(lambda res: self.failUnless(res == self.CONTENTS[3] or
                                                  res == self.CONTENTS[4]))
        return d

class Utils(unittest.TestCase):
    def test_dict_of_sets(self):
        ds = DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        self.failIf(2 in ds)

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        should_be_inside = x.issubset(y)
        self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
                                                         y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_inside(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(c, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        overlap = bool(x.intersection(y))
        self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
                                                      y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_overlap(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(c, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
        self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
        self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
        self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)

        # optional: join fragments
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))