
import struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, download
from allmydata.util import base32
from allmydata.util.idlib import shortnodeid_b2a
from allmydata.util.hashutil import tagged_hash
from allmydata.encode import NotEnoughSharesError
from allmydata.interfaces import IURI, IMutableFileURI, IUploadable
from foolscap.eventual import eventually, fireEventually
from foolscap.logging import log
import sha

from allmydata.mutable.node import MutableFileNode
from allmydata.mutable.common import DictOfSets, ResponseCache, \
     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, UnrecoverableFileError
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.publish import Publish
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.layout import unpack_header, unpack_share

# this "FastMutableFileNode" exists solely to speed up tests by using smaller
# public/private keys. Once we switch to fast DSA-based keys, we can get rid
# of this.

class FastMutableFileNode(MutableFileNode):
    SIGNATURE_KEY_SIZE = 522

# this "FakeStorage" exists to put the share data in RAM and avoid using real
# network connections, both to speed up the tests and to reduce the amount of
# non-mutable.py code being exercised.

class FakeStorage:
    # this class replaces the collection of storage servers, allowing the
    # tests to examine and manipulate the published shares. It also lets us
    # control the order in which read queries are answered, to exercise more
    # of the error-handling code in Retrieve.
    #
    # Note that we ignore the storage index: this FakeStorage instance can
    # only be used for a single storage index.


    def __init__(self):
        self._peers = {}
        # _sequence is used to cause the responses to occur in a specific
        # order. If it is in use, then we will defer queries instead of
        # answering them right away, accumulating the Deferreds in a dict. We
        # don't know exactly how many queries we'll get, so exactly one
        # second after the first query arrives, we will release them all (in
        # order).
        self._sequence = None
        self._pending = {}

    def read(self, peerid, storage_index):
        shares = self._peers.get(peerid, {})
        if self._sequence is None:
            return defer.succeed(shares)
        d = defer.Deferred()
        if not self._pending:
            reactor.callLater(1.0, self._fire_readers)
        self._pending[peerid] = (d, shares)
        return d

    def _fire_readers(self):
        pending = self._pending
        self._pending = {}
        for peerid in self._sequence:
            if peerid in pending:
                d, shares = pending.pop(peerid)
                eventually(d.callback, shares)
        # queries from peers not named in _sequence are answered last, in
        # arbitrary order
        for (d, shares) in pending.values():
            eventually(d.callback, shares)

    def write(self, peerid, storage_index, shnum, offset, data):
        if peerid not in self._peers:
            self._peers[peerid] = {}
        shares = self._peers[peerid]
        f = StringIO()
        f.write(shares.get(shnum, ""))
        f.seek(offset)
        f.write(data)
        shares[shnum] = f.getvalue()
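
# A minimal sketch (a hypothetical helper, not used by the tests below) of
# how FakeStorage._sequence is driven: assigning a list of peerids makes
# read() queue its answers, and _fire_readers() releases them in that order
# one second after the first query arrives. test_multiple_encodings sets
# _sequence directly in exactly this way.
def example_force_read_order(storage, peerids_in_order):
    # peers not named in peerids_in_order are answered last, unordered
    storage._sequence = list(peerids_in_order)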


class FakeStorageServer:
    def __init__(self, peerid, storage):
        self.peerid = peerid
        self.storage = storage
        self.queries = 0
    def callRemote(self, methname, *args, **kwargs):
        def _call():
            meth = getattr(self, methname)
            return meth(*args, **kwargs)
        d = fireEventually()
        d.addCallback(lambda res: _call())
        return d

    def slot_readv(self, storage_index, shnums, readv):
        d = self.storage.read(self.peerid, storage_index)
        def _read(shares):
            response = {}
            for shnum in shares:
                if shnums and shnum not in shnums:
                    continue
                vector = response[shnum] = []
                for (offset, length) in readv:
                    assert isinstance(offset, (int, long)), offset
                    assert isinstance(length, (int, long)), length
                    vector.append(shares[shnum][offset:offset+length])
            return response
        d.addCallback(_read)
        return d

    def slot_testv_and_readv_and_writev(self, storage_index, secrets,
                                        tw_vectors, read_vector):
        # always-pass: parrot the test vectors back to them.
        readv = {}
        for shnum, (testv, writev, new_length) in tw_vectors.items():
            for (offset, length, op, specimen) in testv:
                assert op in ("le", "eq", "ge")
            # TODO: this isn't right, the read is controlled by read_vector,
            # not by testv
            readv[shnum] = [ specimen
                             for (offset, length, op, specimen)
                             in testv ]
            for (offset, data) in writev:
                self.storage.write(self.peerid, storage_index, shnum,
                                   offset, data)
        answer = (True, readv)
        return fireEventually(answer)
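
    # A sketch of what the TODO above calls for (hypothetical, not wired in):
    # answer the read from read_vector against the stored share data, instead
    # of parroting the test-vector specimens back. The always-pass behaviour
    # above is left unchanged because the tests rely on it.
    def _readv_from_read_vector(self, read_vector, shares):
        readv = {}
        for shnum, data in shares.items():
            readv[shnum] = [ data[offset:offset+length]
                             for (offset, length) in read_vector ]
        return readv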


# our "FakeClient" has just enough functionality of the real Client to let
# the tests run.

class FakeClient:
    mutable_file_node_class = FastMutableFileNode

    def __init__(self, num_peers=10):
        self._storage = FakeStorage()
        self._num_peers = num_peers
        self._peerids = [tagged_hash("peerid", "%d" % i)[:20]
                         for i in range(self._num_peers)]
        self._connections = dict([(peerid, FakeStorageServer(peerid,
                                                             self._storage))
                                  for peerid in self._peerids])
        self.nodeid = "fakenodeid"

    def log(self, msg, **kw):
        return log.msg(msg, **kw)

    def get_renewal_secret(self):
        return "I hereby permit you to renew my files"
    def get_cancel_secret(self):
        return "I hereby permit you to cancel my leases"

    def create_mutable_file(self, contents=""):
        n = self.mutable_file_node_class(self)
        d = n.create(contents)
        d.addCallback(lambda res: n)
        return d

    def notify_retrieve(self, r):
        pass
    def notify_publish(self, p):
        pass
    def notify_mapupdate(self, u):
        pass

    def create_node_from_uri(self, u):
        u = IURI(u)
        assert IMutableFileURI.providedBy(u), u
        res = self.mutable_file_node_class(self).init_from_uri(u)
        return res

    def get_permuted_peers(self, service_name, key):
        """
        @return: list of (peerid, connection,)
        """
        results = []
        for (peerid, connection) in self._connections.items():
            assert isinstance(peerid, str)
            permuted = sha.new(key + peerid).digest()
            results.append((permuted, peerid, connection))
        results.sort()
        results = [ (r[1],r[2]) for r in results]
        return results

    def upload(self, uploadable):
        assert IUploadable.providedBy(uploadable)
        d = uploadable.get_size()
        d.addCallback(lambda length: uploadable.read(length))
        #d.addCallback(self.create_mutable_file)
        def _got_data(datav):
            data = "".join(datav)
            #newnode = FastMutableFileNode(self)
            return uri.LiteralFileURI(data)
        d.addCallback(_got_data)
        return d


def flip_bit(original, byte_offset):
    return (original[:byte_offset] +
            chr(ord(original[byte_offset]) ^ 0x01) +
            original[byte_offset+1:])

def corrupt(res, s, offset, shnums_to_corrupt=None):
    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
    # list of shnums to corrupt.
    for peerid in s._peers:
        shares = s._peers[peerid]
        for shnum in shares:
            if (shnums_to_corrupt is not None
                and shnum not in shnums_to_corrupt):
                continue
            data = shares[shnum]
            (version,
             seqnum,
             root_hash,
             IV,
             k, N, segsize, datalen,
             o) = unpack_header(data)
            if isinstance(offset, tuple):
                offset1, offset2 = offset
            else:
                offset1 = offset
                offset2 = 0
            if offset1 == "pubkey":
                real_offset = 107
            elif offset1 in o:
                real_offset = o[offset1]
            else:
                real_offset = offset1
            real_offset = int(real_offset) + offset2
            assert isinstance(real_offset, int), offset
            shares[shnum] = flip_bit(data, real_offset)
    return res
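
# corrupt() accepts three kinds of offset: a raw integer byte offset (e.g. 1
# lands in the sequence-number field), a field name looked up in the offsets
# dict unpacked from each share's header (e.g. "share_hash_chain"), or a
# (name, delta) tuple such as ("share_hash_chain", 4) to flip a bit a few
# bytes into that field. "pubkey" is special-cased to a fixed offset of 107.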

class Filenode(unittest.TestCase):
    def setUp(self):
        self.client = FakeClient()

    def test_create(self):
        d = self.client.create_mutable_file()
        def _created(n):
            self.failUnless(isinstance(n, FastMutableFileNode))
            peer0 = self.client._peerids[0]
            shnums = self.client._storage._peers[peer0].keys()
            self.failUnlessEqual(len(shnums), 1)
        d.addCallback(_created)
        return d

    def test_upload_and_download(self):
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d

    def test_create_with_initial_contents(self):
        d = self.client.create_mutable_file("contents 1")
        def _created(n):
            d = n.download_best_version()
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            return d
        d.addCallback(_created)
        return d

    def test_upload_and_download_full_size_keys(self):
        self.client.mutable_file_node_class = MutableFileNode
        d = self.client.create_mutable_file()
        def _created(n):
            d = defer.succeed(None)
            d.addCallback(lambda res: n.get_servermap(MODE_READ))
            d.addCallback(lambda smap: smap.dump(StringIO()))
            d.addCallback(lambda sio:
                          self.failUnless("3-of-10" in sio.getvalue()))
            d.addCallback(lambda res: n.overwrite("contents 1"))
            d.addCallback(lambda res: self.failUnlessIdentical(res, None))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 1"))
            d.addCallback(lambda res: n.overwrite("contents 2"))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.download(download.Data()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
            d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
            d.addCallback(lambda smap: n.upload("contents 3", smap))
            d.addCallback(lambda res: n.download_best_version())
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            d.addCallback(lambda res: n.get_servermap(MODE_ANYTHING))
            d.addCallback(lambda smap:
                          n.download_version(smap,
                                             smap.best_recoverable_version()))
            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 3"))
            return d
        d.addCallback(_created)
        return d


class MakeShares(unittest.TestCase):
    def test_encrypt(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            p.salt = "SALT" * 4
            p.readkey = "\x00" * 16
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            return p._encrypt_and_encode()
        d.addCallback(_created)
        def _done(shares_and_shareids):
            (shares, share_ids) = shares_and_shareids
            self.failUnlessEqual(len(shares), 10)
            for sh in shares:
                self.failUnless(isinstance(sh, str))
                self.failUnlessEqual(len(sh), 7)
            self.failUnlessEqual(len(share_ids), 10)
        d.addCallback(_done)
        return d

    def test_generate(self):
        c = FakeClient()
        fn = FastMutableFileNode(c)
        CONTENTS = "some initial contents"
        d = fn.create(CONTENTS)
        def _created(res):
            p = Publish(fn, None)
            self._p = p
            p.newdata = CONTENTS
            p.required_shares = 3
            p.total_shares = 10
            p.setup_encoding_parameters()
            p._new_seqnum = 3
            p.salt = "SALT" * 4
            # make some fake shares
            shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) )
            p._privkey = fn.get_privkey()
            p._encprivkey = fn.get_encprivkey()
            p._pubkey = fn.get_pubkey()
            return p._generate_shares(shares_and_ids)
        d.addCallback(_created)
        def _generated(res):
            p = self._p
            final_shares = p.shares
            root_hash = p.root_hash
            self.failUnlessEqual(len(root_hash), 32)
            self.failUnless(isinstance(final_shares, dict))
            self.failUnlessEqual(len(final_shares), 10)
            self.failUnlessEqual(sorted(final_shares.keys()), range(10))
            for i,sh in final_shares.items():
                self.failUnless(isinstance(sh, str))
                # feed the share through the unpacker as a sanity-check
                pieces = unpack_share(sh)
                (u_seqnum, u_root_hash, IV, k, N, segsize, datalen,
                 pubkey, signature, share_hash_chain, block_hash_tree,
                 share_data, enc_privkey) = pieces
                self.failUnlessEqual(u_seqnum, 3)
                self.failUnlessEqual(u_root_hash, root_hash)
                self.failUnlessEqual(k, 3)
                self.failUnlessEqual(N, 10)
                self.failUnlessEqual(segsize, 21)
                self.failUnlessEqual(datalen, len(CONTENTS))
                self.failUnlessEqual(pubkey, p._pubkey.serialize())
                sig_material = struct.pack(">BQ32s16s BBQQ",
                                           0, p._new_seqnum, root_hash, IV,
                                           k, N, segsize, datalen)
                self.failUnless(p._pubkey.verify(sig_material, signature))
                #self.failUnlessEqual(signature, p._privkey.sign(sig_material))
                self.failUnless(isinstance(share_hash_chain, dict))
                self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
                for shnum,share_hash in share_hash_chain.items():
                    self.failUnless(isinstance(shnum, int))
                    self.failUnless(isinstance(share_hash, str))
                    self.failUnlessEqual(len(share_hash), 32)
                self.failUnless(isinstance(block_hash_tree, list))
                self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                self.failUnlessEqual(IV, "SALT"*4)
                self.failUnlessEqual(len(share_data), len("%07d" % 1))
                self.failUnlessEqual(enc_privkey, fn.get_encprivkey())
        d.addCallback(_generated)
        return d

    # TODO: when we publish to 20 peers, we should get one share per peer,
    # on 10 of them; when we publish to 3 peers, we should get either 3 or 4
    # shares per peer; when we publish to zero peers, we should get a
    # NotEnoughSharesError (a sketch of that last case follows).
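
    # A sketch of the zero-peers case from the TODO above. It assumes the
    # publish performed by create() surfaces NotEnoughSharesError when no
    # peers exist; if the failure is reported differently, the trap below
    # would need to name that exception instead.
    def test_publish_no_peers(self):
        c = FakeClient(num_peers=0)
        d = c.create_mutable_file("contents go here")
        def _should_not_succeed(n):
            self.fail("publish to zero peers was supposed to fail")
        def _trap(f):
            f.trap(NotEnoughSharesError)
        d.addCallbacks(_should_not_succeed, _trap)
        return d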

class Servermap(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file("New contents go here")
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, ServerMap(), mode)
        d = smu.update()
        return d

    def update_servermap(self, oldmap, mode=MODE_CHECK):
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def failUnlessOneRecoverable(self, sm, num_shares):
        self.failUnlessEqual(len(sm.recoverable_versions()), 1)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failIfEqual(best, None)
        self.failUnlessEqual(sm.recoverable_versions(), set([best]))
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3))
        return sm

    def test_basic(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        # this mode stops at k+epsilon, and epsilon=k, so 6 shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        # this mode stops at 'k' shares
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))

        # and can we re-use the same servermap? Note that these are sorted in
        # increasing order of number of servers queried, since once a server
        # gets into the servermap, we'll always ask it for an update.
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3))
        d.addCallback(lambda sm: us(sm, mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        d.addCallback(lambda sm: us(sm, mode=MODE_WRITE))
        d.addCallback(lambda sm: us(sm, mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
        d.addCallback(lambda sm: us(sm, mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))

        return d

    def test_mark_bad(self):
        d = defer.succeed(None)
        ms = self.make_servermap
        us = self.update_servermap

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
        def _made_map(sm):
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            self.failUnlessEqual(len(shares), 6)
            self._corrupted = set()
            # mark the first 5 shares as corrupt, then update the servermap.
            # The map should not have the marked shares in it any more, and
            # new shares should be found to replace the missing ones.
            for (shnum, peerid, timestamp) in shares:
                if shnum < 5:
                    self._corrupted.add( (peerid, shnum) )
                    sm.mark_bad_share(peerid, shnum)
            return self.update_servermap(sm, MODE_WRITE)
        d.addCallback(_made_map)
        def _check_map(sm):
            # this should find all 5 shares that weren't marked bad
            v = sm.best_recoverable_version()
            vm = sm.make_versionmap()
            shares = list(vm[v])
            for (peerid, shnum) in self._corrupted:
                peer_shares = sm.shares_on_peer(peerid)
                self.failIf(shnum in peer_shares,
                            "%d was in %s" % (shnum, peer_shares))
            self.failUnlessEqual(len(shares), 5)
        d.addCallback(_check_map)
        return d

    def failUnlessNoneRecoverable(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 0)

    def test_no_shares(self):
        self._client._storage._peers = {} # delete all shares
        ms = self.make_servermap
        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm))

        return d

    def failUnlessNotQuiteEnough(self, sm):
        self.failUnlessEqual(len(sm.recoverable_versions()), 0)
        self.failUnlessEqual(len(sm.unrecoverable_versions()), 1)
        best = sm.best_recoverable_version()
        self.failUnlessEqual(best, None)
        self.failUnlessEqual(len(sm.shares_available()), 1)
        self.failUnlessEqual(sm.shares_available().values()[0], (2,3) )

    def test_not_quite_enough_shares(self):
        s = self._client._storage
        ms = self.make_servermap
        num_shares = len(s._peers)
        for peerid in s._peers:
            s._peers[peerid] = {}
            num_shares -= 1
            if num_shares == 2:
                break
        # now there ought to be only two shares left
        assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2

        d = defer.succeed(None)

        d.addCallback(lambda res: ms(mode=MODE_CHECK))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_WRITE))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
        d.addCallback(lambda res: ms(mode=MODE_READ))
        d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))

        return d


class Roundtrip(unittest.TestCase):
    def setUp(self):
        # publish a file and create shares, which can then be manipulated
        # later.
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def abbrev_verinfo(self, verinfo):
        if verinfo is None:
            return None
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4])

    def abbrev_verinfo_dict(self, verinfo_d):
        output = {}
        for verinfo,value in verinfo_d.items():
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = verinfo
            output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value
        return output

    def dump_servermap(self, servermap):
        print "SERVERMAP", servermap
        print "RECOVERABLE", [self.abbrev_verinfo(v)
                              for v in servermap.recoverable_versions()]
        print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version())
        print "available", self.abbrev_verinfo_dict(servermap.shares_available())

    def do_download(self, servermap, version=None):
        if version is None:
            version = servermap.best_recoverable_version()
        r = Retrieve(self._fn, servermap, version)
        return r.download()

    def test_basic(self):
        d = self.make_servermap()
        def _do_retrieve(servermap):
            self._smap = servermap
            #self.dump_servermap(servermap)
            self.failUnlessEqual(len(servermap.recoverable_versions()), 1)
            return self.do_download(servermap)
        d.addCallback(_do_retrieve)
        def _retrieved(new_contents):
            self.failUnlessEqual(new_contents, self.CONTENTS)
        d.addCallback(_retrieved)
        # we should be able to re-use the same servermap, both with and
        # without updating it.
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        # clobbering the pubkey should make the servermap updater re-fetch it
        def _clobber_pubkey(res):
            self._fn._pubkey = None
        d.addCallback(_clobber_pubkey)
        d.addCallback(lambda res: self.make_servermap(oldmap=self._smap))
        d.addCallback(lambda res: self.do_download(self._smap))
        d.addCallback(_retrieved)
        return d


    def shouldFail(self, expected_failure, which, substring,
                    callable, *args, **kwargs):
        assert substring is None or isinstance(substring, str)
        d = defer.maybeDeferred(callable, *args, **kwargs)
        def done(res):
            if isinstance(res, failure.Failure):
                res.trap(expected_failure)
                if substring:
                    self.failUnless(substring in str(res),
                                    "substring '%s' not in '%s'"
                                    % (substring, str(res)))
            else:
                self.fail("%s was supposed to raise %s, not get '%s'" %
                          (which, expected_failure, res))
        d.addBoth(done)
        return d
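
    # shouldFail usage, as exercised by test_download_fails below: name the
    # expected exception type, a label for failure messages, a substring that
    # must appear in the failure's text (or None to skip that check), and
    # then the callable under test with its arguments:
    #
    #   self.shouldFail(UnrecoverableFileError, "test_download_anyway",
    #                   "no recoverable versions",
    #                   self._fn.download_best_version)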

    def _test_corrupt_all(self, offset, substring,
                          should_succeed=False, corrupt_early=True):
        d = defer.succeed(None)
        if corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        d.addCallback(lambda res: self.make_servermap())
        if not corrupt_early:
            d.addCallback(corrupt, self._storage, offset)
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            if ver is None and not should_succeed:
                # no recoverable versions == not succeeding. The problem
                # should be noted in the servermap's list of problems.
                if substring:
                    allproblems = [str(f) for f in servermap.problems]
                    self.failUnless(substring in "".join(allproblems))
                return
            if should_succeed:
                d1 = self._fn.download_best_version()
                d1.addCallback(lambda new_contents:
                               self.failUnlessEqual(new_contents, self.CONTENTS))
                return d1
            else:
                return self.shouldFail(NotEnoughSharesError,
                                       "_corrupt_all(offset=%s)" % (offset,),
                                       substring,
                                       self._fn.download_best_version)
        d.addCallback(_do_retrieve)
        return d

    def test_corrupt_all_verbyte(self):
        # when the version byte is not 0, we hit an assertion error in
        # unpack_share().
        return self._test_corrupt_all(0, "AssertionError")

    def test_corrupt_all_seqnum(self):
        # a corrupt sequence number will trigger a bad signature
        return self._test_corrupt_all(1, "signature is invalid")

    def test_corrupt_all_R(self):
        # a corrupt root hash will trigger a bad signature
        return self._test_corrupt_all(9, "signature is invalid")

    def test_corrupt_all_IV(self):
        # a corrupt salt/IV will trigger a bad signature
        return self._test_corrupt_all(41, "signature is invalid")

    def test_corrupt_all_k(self):
        # a corrupt 'k' will trigger a bad signature
        return self._test_corrupt_all(57, "signature is invalid")

    def test_corrupt_all_N(self):
        # a corrupt 'N' will trigger a bad signature
        return self._test_corrupt_all(58, "signature is invalid")

    def test_corrupt_all_segsize(self):
        # a corrupt segsize will trigger a bad signature
        return self._test_corrupt_all(59, "signature is invalid")

    def test_corrupt_all_datalen(self):
        # a corrupt data length will trigger a bad signature
        return self._test_corrupt_all(67, "signature is invalid")

    def test_corrupt_all_pubkey(self):
        # a corrupt pubkey won't match the URI's fingerprint. We need to
        # remove the pubkey from the filenode, or else it won't bother trying
        # to update it.
        self._fn._pubkey = None
        return self._test_corrupt_all("pubkey",
                                      "pubkey doesn't match fingerprint")

    def test_corrupt_all_sig(self):
        # a corrupt signature is a bad one
        # the signature runs from about [543:799], depending upon the length
        # of the pubkey
        return self._test_corrupt_all("signature", "signature is invalid")

    def test_corrupt_all_share_hash_chain_number(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle the first byte, that will look like a bad hash number,
        # causing an IndexError
        return self._test_corrupt_all("share_hash_chain", "corrupt hashes")

    def test_corrupt_all_share_hash_chain_hash(self):
        # a corrupt share hash chain entry will show up as a bad hash. If we
        # mangle a few bytes in, that will look like a bad hash.
        return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes")

    def test_corrupt_all_block_hash_tree(self):
        return self._test_corrupt_all("block_hash_tree",
                                      "block hash tree failure")

    def test_corrupt_all_block(self):
        return self._test_corrupt_all("share_data", "block hash tree failure")

    def test_corrupt_all_encprivkey(self):
        # a corrupted privkey won't even be noticed by the reader, only by a
        # writer.
        return self._test_corrupt_all("enc_privkey", None, should_succeed=True)

    def test_basic_pubkey_at_end(self):
        # we corrupt the pubkey in all but the last 'k' shares, allowing the
        # download to succeed but forcing a bunch of retries first. Note that
        # this is rather pessimistic: our Retrieve process will throw away
        # the whole share if the pubkey is bad, even though the rest of the
        # share might be good.

        self._fn._pubkey = None
        k = self._fn.get_required_shares()
        N = self._fn.get_total_shares()
        d = defer.succeed(None)
        d.addCallback(corrupt, self._storage, "pubkey",
                      shnums_to_corrupt=range(0, N-k))
        d.addCallback(lambda res: self.make_servermap())
        def _do_retrieve(servermap):
            self.failUnless(servermap.problems)
            self.failUnless("pubkey doesn't match fingerprint"
                            in str(servermap.problems[0]))
            ver = servermap.best_recoverable_version()
            r = Retrieve(self._fn, servermap, ver)
            return r.download()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_corrupt_some(self):
        # corrupt the data of the first five shares (so the servermap thinks
        # they're good but retrieve marks them as bad), so that the
        # MODE_READ set of 6 will be insufficient, forcing node.download to
        # retry with more servers.
        corrupt(None, self._storage, "share_data", range(5))
        d = self.make_servermap()
        def _do_retrieve(servermap):
            ver = servermap.best_recoverable_version()
            self.failUnless(ver)
            return self._fn.download_best_version()
        d.addCallback(_do_retrieve)
        d.addCallback(lambda new_contents:
                      self.failUnlessEqual(new_contents, self.CONTENTS))
        return d

    def test_download_fails(self):
        corrupt(None, self._storage, "signature")
        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
                            "no recoverable versions",
                            self._fn.download_best_version)
        return d


class MultipleEncodings(unittest.TestCase):
    def setUp(self):
        self.CONTENTS = "New contents go here"
        num_peers = 20
        self._client = FakeClient(num_peers)
        self._storage = self._client._storage
        d = self._client.create_mutable_file(self.CONTENTS)
        def _created(node):
            self._fn = node
        d.addCallback(_created)
        return d

    def _encode(self, k, n, data):
        # encode 'data' into a peerid->shares dict.

        fn2 = FastMutableFileNode(self._client)
        # init_from_uri populates _uri, _writekey, _readkey, _storage_index,
        # and _fingerprint
        fn = self._fn
        fn2.init_from_uri(fn.get_uri())
        # then we copy over other fields that are normally fetched from the
        # existing shares
        fn2._pubkey = fn._pubkey
        fn2._privkey = fn._privkey
        fn2._encprivkey = fn._encprivkey
        fn2._current_seqnum = 0
        fn2._current_roothash = "\x00" * 32
        # and set the encoding parameters to something completely different
        fn2._required_shares = k
        fn2._total_shares = n

        s = self._client._storage
        s._peers = {} # clear existing storage
        p2 = Publish(fn2, None)
        d = p2.publish(data)
        def _published(res):
            shares = s._peers
            s._peers = {}
            return shares
        d.addCallback(_published)
        return d

    def make_servermap(self, mode=MODE_READ, oldmap=None):
        if oldmap is None:
            oldmap = ServerMap()
        smu = ServermapUpdater(self._fn, oldmap, mode)
        d = smu.update()
        return d

    def test_multiple_encodings(self):
        # we encode the same file in three different ways (3-of-10, 4-of-9,
        # and 4-of-7), then mix up the shares, to make sure that download
        # survives seeing a variety of encodings. This is actually kind of
        # tricky to set up.

        contents1 = "Contents for encoding 1 (3-of-10) go here"
        contents2 = "Contents for encoding 2 (4-of-9) go here"
        contents3 = "Contents for encoding 3 (4-of-7) go here"

        # we make a retrieval object that doesn't know what encoding
        # parameters to use
        fn3 = FastMutableFileNode(self._client)
        fn3.init_from_uri(self._fn.get_uri())

        # now we upload the file once per encoding, and grab the resulting
        # shares
        d = self._encode(3, 10, contents1)
        def _encoded_1(shares):
            self._shares1 = shares
        d.addCallback(_encoded_1)
        d.addCallback(lambda res: self._encode(4, 9, contents2))
        def _encoded_2(shares):
            self._shares2 = shares
        d.addCallback(_encoded_2)
        d.addCallback(lambda res: self._encode(4, 7, contents3))
        def _encoded_3(shares):
            self._shares3 = shares
        d.addCallback(_encoded_3)

        def _merge(res):
            log.msg("merging sharelists")
            # we merge the shares from the three sets, leaving each shnum in
            # its original location, but using a share from set1, set2, or
            # set3 according to the following sequence:
            #
            #  4-of-9   a  s2
            #  4-of-9   b  s2
            #  4-of-7   c  s3
            #  4-of-9   d  s2
            #  3-of-10  e  s1
            #  3-of-10  f  s1
            #  3-of-10  g  s1
            #  4-of-9   h  s2
            #
            # so that neither form can be recovered until fetch [f], at which
            # point version-s1 (the 3-of-10 form) should be recoverable. If
            # the implementation latches on to the first version it sees,
            # then s2 will be recoverable at fetch [g].

            # Later, when we implement code that handles multiple versions,
            # we can use this framework to assert that all recoverable
            # versions are retrieved, and test that 'epsilon' does its job

            places = [2, 2, 3, 2, 1, 1, 1, 2]

            sharemap = {}

            for i,peerid in enumerate(self._client._peerids):
                peerid_s = shortnodeid_b2a(peerid)
                for shnum in self._shares1.get(peerid, {}):
                    if shnum < len(places):
                        which = places[shnum]
                    else:
                        which = "x"
                    self._client._storage._peers[peerid] = peers = {}
                    in_1 = shnum in self._shares1[peerid]
                    in_2 = shnum in self._shares2.get(peerid, {})
                    in_3 = shnum in self._shares3.get(peerid, {})
                    #print peerid_s, shnum, which, in_1, in_2, in_3
                    if which == 1:
                        if in_1:
                            peers[shnum] = self._shares1[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 2:
                        if in_2:
                            peers[shnum] = self._shares2[peerid][shnum]
                            sharemap[shnum] = peerid
                    elif which == 3:
                        if in_3:
                            peers[shnum] = self._shares3[peerid][shnum]
                            sharemap[shnum] = peerid

            # we don't bother placing any other shares
            # now sort the sequence so that share 0 is returned first
            new_sequence = [sharemap[shnum]
                            for shnum in sorted(sharemap.keys())]
            self._client._storage._sequence = new_sequence
            log.msg("merge done")
        d.addCallback(_merge)
        d.addCallback(lambda res: fn3.download_best_version())
        def _retrieved(new_contents):
            # the current specified behavior is "first version recoverable"
            self.failUnlessEqual(new_contents, contents1)
        d.addCallback(_retrieved)
        return d


class Utils(unittest.TestCase):
    def test_dict_of_sets(self):
        ds = DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        self.failIf(2 in ds)

    def _do_inside(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        should_be_inside = x.issubset(y)
        self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
                                                         y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_inside(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_inside(c, x_start, x_length, y_start, y_length)

    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
        # we compare this against sets of integers
        x = set(range(x_start, x_start+x_length))
        y = set(range(y_start, y_start+y_length))
        overlap = bool(x.intersection(y))
        self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
                                                      y_start, y_length),
                             str((x_start, x_length, y_start, y_length)))

    def test_cache_overlap(self):
        c = ResponseCache()
        x_start = 10
        x_length = 5
        for y_start in range(8, 17):
            for y_length in range(8):
                self._do_overlap(c, x_start, x_length, y_start, y_length)

    def test_cache(self):
        c = ResponseCache()
        # xdata = base62.b2a(os.urandom(100))[:100]
        xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
        ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
        nope = (None, None)
        c.add("v1", 1, 0, xdata, "time0")
        c.add("v1", 1, 2000, ydata, "time1")
        self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
        self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
        self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
        self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
        self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
        self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)

        # optional: join fragments
        c = ResponseCache()
        c.add("v1", 1, 0, xdata[:10], "time0")
        c.add("v1", 1, 10, xdata[10:20], "time1")
        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))