# Source: tahoe-lafs git repository (git.rkrishnan.org mirror)
# File: src/allmydata/test/test_encode.py
# Commit: "verify hash chains on incoming blocks"
1 #! /usr/bin/env python
2
3 from twisted.trial import unittest
4 from twisted.internet import defer
5 from foolscap import eventual
6 from allmydata import encode, download
7 from allmydata.uri import pack_uri
8 from cStringIO import StringIO
9
class FakePeer:
    """Stand-in for a remote peer: it owns one fake storage server and
    dispatches "remote" method calls to its own local attributes."""

    def __init__(self, mode="good"):
        # The peer's only published service is a single fake storage
        # server; 'mode' controls that server's allocation behavior.
        self.ss = FakeStorageServer(mode)

    def callRemote(self, methname, *args, **kwargs):
        # Look up the named method on ourselves and invoke it, wrapping
        # the result (or any exception) in a Deferred the way a real
        # remote reference would.
        def _invoke():
            return getattr(self, methname)(*args, **kwargs)
        return defer.maybeDeferred(_invoke)

    def get_service(self, sname):
        # Only the storage-server service exists on this fake.
        assert sname == "storageserver"
        return self.ss
23
class FakeStorageServer:
    """Stand-in for a storage server.  Remote calls are postponed to a
    later reactor turn, mimicking a real network round-trip."""

    def __init__(self, mode):
        # mode selects allocate_buckets' answer: "full" (no space),
        # "already got them" (shares already stored), anything else
        # (accept every requested share).
        self.mode = mode

    def callRemote(self, methname, *args, **kwargs):
        # Fire on a later turn of the event loop, then dispatch the call
        # to the like-named local method.
        d = eventual.fireEventually()
        d.addCallback(lambda ignored: getattr(self, methname)(*args, **kwargs))
        return d

    def allocate_buckets(self, verifierid, sharenums, shareize, blocksize, canary):
        if self.mode == "full":
            # Server is out of space: nothing accepted, no writers offered.
            return (set(), {})
        if self.mode == "already got them":
            # Every requested share is already present on this server.
            return (set(sharenums), {})
        # Normal case: grant a fresh bucket writer for each requested share.
        return (set(), dict((shnum, FakeBucketWriter()) for shnum in sharenums))
41
class FakeBucketWriter:
    """In-memory bucket that plays both the upload (put_*) and download
    (get_*) roles: it records blocks and hash lists during encoding and
    serves them back to the downloader afterwards."""

    def __init__(self):
        self.blocks = {}          # segmentnum -> block data
        self.block_hashes = None  # block-hash-tree nodes, set once
        self.share_hashes = None  # (index, hash) chain entries, set once
        self.closed = False       # no writes allowed after close()

    def callRemote(self, methname, *args, **kwargs):
        # Deliver the local method's result through a Deferred, as a
        # real remote reference would.
        def _invoke():
            return getattr(self, methname)(*args, **kwargs)
        return defer.maybeDeferred(_invoke)

    def put_block(self, segmentnum, data):
        assert not self.closed
        # each segment may be written at most once
        assert segmentnum not in self.blocks
        self.blocks[segmentnum] = data

    def put_block_hashes(self, blockhashes):
        assert not self.closed
        # the block-hash tree is delivered exactly once
        assert self.block_hashes is None
        self.block_hashes = blockhashes

    def put_share_hashes(self, sharehashes):
        assert not self.closed
        # the share-hash chain is delivered exactly once
        assert self.share_hashes is None
        self.share_hashes = sharehashes

    def close(self):
        assert not self.closed
        self.closed = True

    def get_block(self, blocknum):
        assert isinstance(blocknum, int)
        return self.blocks[blocknum]

    def get_block_hashes(self):
        return self.block_hashes

    def get_share_hashes(self):
        return self.share_hashes
83
84
class Encode(unittest.TestCase):
    def test_send(self):
        """Encode a small multi-segment file to 100 fake shareholders and
        check that each one ends up closed with blocks and hash chains of
        the expected shape."""
        enc = encode.Encoder()
        data = "happy happy joy joy" * 4
        enc.setup(StringIO(data))
        NUM_SHARES = 100
        assert enc.num_shares == NUM_SHARES # else we'll be completely confused
        enc.segment_size = 25 # force use of multiple segments
        enc.setup_codec() # need to rebuild the codec for that change
        NUM_SEGMENTS = 4
        assert (NUM_SEGMENTS-1)*enc.segment_size < len(data) <= NUM_SEGMENTS*enc.segment_size
        # one distinct FakeBucketWriter per share number
        all_shareholders = [FakeBucketWriter() for i in range(NUM_SHARES)]
        shareholders = dict(enumerate(all_shareholders))
        enc.set_shareholders(shareholders)
        d = enc.start()
        def _check(roothash):
            # the encoder hands back the share-hash-tree root
            self.failUnless(isinstance(roothash, str))
            self.failUnlessEqual(len(roothash), 32)
            for peer in all_shareholders:
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # we can't expect exactly NUM_SEGMENTS block hashes: each
                # peer gets a full tree, so it's more like 2n-1 entries
                # with rounding up to a power of two
                for h in peer.block_hashes:
                    self.failUnlessEqual(len(h), 32)
                # likewise not NUM_SHARES share hashes: each peer only
                # gets the chain it needs
                for (hashnum, h) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(h), 32)
        d.addCallback(_check)
        return d
123
class Roundtrip(unittest.TestCase):
    def send_and_recover(self, NUM_SHARES, NUM_PEERS, NUM_SEGMENTS=4):
        """Encode a small file, distribute its shares round-robin over
        NUM_PEERS fake buckets, then download it again and verify the
        recovered plaintext matches the original."""
        enc = encode.Encoder()
        data = "happy happy joy joy" * 4
        enc.setup(StringIO(data))

        assert enc.num_shares == NUM_SHARES # else we'll be completely confused
        enc.segment_size = 25 # force use of multiple segments
        enc.setup_codec() # need to rebuild the codec for that change

        assert (NUM_SEGMENTS-1)*enc.segment_size < len(data) <= NUM_SEGMENTS*enc.segment_size
        # round-robin the shares over a (possibly smaller) set of peers
        all_peers = [FakeBucketWriter() for i in range(NUM_PEERS)]
        all_shareholders = [all_peers[shnum % NUM_PEERS]
                            for shnum in range(NUM_SHARES)]
        shareholders = dict(enumerate(all_shareholders))
        enc.set_shareholders(shareholders)
        d = enc.start()
        def _uploaded(roothash):
            # build the URI a real uploader would hand out, then feed the
            # same buckets straight into a downloader
            URI = pack_uri(enc._codec.get_encoder_type(),
                           enc._codec.get_serialized_params(),
                           enc._tail_codec.get_serialized_params(),
                           "V" * 20,
                           roothash,
                           enc.required_shares,
                           enc.num_shares,
                           enc.file_size,
                           enc.segment_size)
            client = None
            target = download.Data()
            fd = download.FileDownloader(client, URI, target)
            for shnum in range(NUM_SHARES):
                fd.add_share_bucket(shnum, all_shareholders[shnum])
            fd._got_all_shareholders(None)
            d2 = fd._download_all_segments(None)
            d2.addCallback(fd._done)
            return d2
        d.addCallback(_uploaded)
        def _downloaded(newdata):
            self.failUnless(newdata == data)
        d.addCallback(_downloaded)

        return d

    def test_one_share_per_peer(self):
        return self.send_and_recover(100, 100)
175