git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
move testutil into test/common_util.py, since it doesn't count as 'code under test...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): return None # must stay on line 2: HumanReadable.test_repr expects hr(foo) to report test_util.py:2
3
4 import os, time
5 from twisted.trial import unittest
6 from twisted.internet import defer, reactor
7 from twisted.python import failure
8
9 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
10 from allmydata.util import assertutil, fileutil, deferredutil
11 from allmydata.util import limiter, time_format, pollmixin
12
class Base32(unittest.TestCase):
    """Tests for the encoder/decoder in allmydata.util.base32."""

    def test_b2a_matches_Pythons(self):
        # Our encoding should agree with the stdlib's once the stdlib's
        # trailing '=' padding is removed and the result is lower-cased.
        import base64
        raw = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        expected = base64.b32encode(raw).rstrip("=").lower()
        self.failUnlessEqual(base32.b2a(raw), expected)

    def test_b2a(self):
        self.failUnlessEqual(base32.b2a("\x12\x34"), "ci2a")

    def test_b2a_or_none(self):
        encode = base32.b2a_or_none
        # None passes straight through; anything else is encoded like b2a
        self.failUnlessEqual(encode(None), None)
        self.failUnlessEqual(encode("\x12\x34"), "ci2a")

    def test_a2b(self):
        decode = base32.a2b
        self.failUnlessEqual(decode("ci2a"), "\x12\x34")
        # malformed input must be rejected with an AssertionError
        # (presumably because '0' is not a base32 character)
        self.failUnlessRaises(AssertionError, decode, "b0gus")
30
class IDLib(unittest.TestCase):
    """Tests for node-id formatting in allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        # a 20-byte all-zero node id renders as 32 'a' characters
        zero_id = 20 * "\x00"
        self.failUnlessEqual(idlib.nodeid_b2a(zero_id), 32 * "a")
34
class NoArgumentException(Exception):
    """An exception whose constructor takes no arguments.

    HumanReadable.test_repr raises it to see how hr() renders an exception
    that carries no args.
    """
    def __init__(self):
        # Exception.__init__ is intentionally not invoked here.
        return None
38
39 class HumanReadable(unittest.TestCase):
40     def test_repr(self):
41         hr = humanreadable.hr
42         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
43         self.failUnlessEqual(hr(self.test_repr),
44                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
45         self.failUnlessEqual(hr(1L), "1")
46         self.failUnlessEqual(hr(10**40),
47                              "100000000000000000...000000000000000000")
48         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
49         self.failUnlessEqual(hr([1,2]), "[1, 2]")
50         self.failUnlessEqual(hr({1:2}), "{1:2}")
51         try:
52             raise RuntimeError
53         except Exception, e:
54             self.failUnless(
55                 hr(e) == "<RuntimeError: ()>" # python-2.4
56                 or hr(e) == "RuntimeError()") # python-2.5
57         try:
58             raise RuntimeError("oops")
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<RuntimeError: 'oops'>" # python-2.4
62                 or hr(e) == "RuntimeError('oops',)") # python-2.5
63         try:
64             raise NoArgumentException
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<NoArgumentException>" # python-2.4
68                 or hr(e) == "NoArgumentException()") # python-2.5
69
70
class MyList(list):
    """A list subclass that adds no behavior of its own."""
73
class Math(unittest.TestCase):
    """Tests for the integer helpers in allmydata.util.mathutil."""

    def test_div_ceil(self):
        f = mathutil.div_ceil
        self.failUnlessEqual(f(0, 1), 0)
        self.failUnlessEqual(f(0, 2), 0)
        self.failUnlessEqual(f(0, 3), 0)
        self.failUnlessEqual(f(1, 3), 1)
        self.failUnlessEqual(f(2, 3), 1)
        self.failUnlessEqual(f(3, 3), 1)
        self.failUnlessEqual(f(4, 3), 2)
        self.failUnlessEqual(f(5, 3), 2)
        self.failUnlessEqual(f(6, 3), 2)
        self.failUnlessEqual(f(7, 3), 3)

    def test_next_multiple(self):
        f = mathutil.next_multiple
        self.failUnlessEqual(f(5, 1), 5)
        self.failUnlessEqual(f(5, 2), 6)
        self.failUnlessEqual(f(5, 3), 6)
        self.failUnlessEqual(f(5, 4), 8)
        self.failUnlessEqual(f(5, 5), 5)
        self.failUnlessEqual(f(5, 6), 6)
        self.failUnlessEqual(f(32, 1), 32)
        self.failUnlessEqual(f(32, 2), 32)
        self.failUnlessEqual(f(32, 3), 33)
        self.failUnlessEqual(f(32, 4), 32)
        self.failUnlessEqual(f(32, 5), 35)
        self.failUnlessEqual(f(32, 6), 36)
        self.failUnlessEqual(f(32, 7), 35)
        self.failUnlessEqual(f(32, 8), 32)
        self.failUnlessEqual(f(32, 9), 36)
        self.failUnlessEqual(f(32, 10), 40)
        self.failUnlessEqual(f(32, 11), 33)
        self.failUnlessEqual(f(32, 12), 36)
        self.failUnlessEqual(f(32, 13), 39)
        self.failUnlessEqual(f(32, 14), 42)
        self.failUnlessEqual(f(32, 15), 45)
        self.failUnlessEqual(f(32, 16), 32)
        self.failUnlessEqual(f(32, 17), 34)
        self.failUnlessEqual(f(32, 18), 36)
        self.failUnlessEqual(f(32, 589), 589)

    def test_pad_size(self):
        f = mathutil.pad_size
        self.failUnlessEqual(f(0, 4), 0)
        self.failUnlessEqual(f(1, 4), 3)
        self.failUnlessEqual(f(2, 4), 2)
        self.failUnlessEqual(f(3, 4), 1)
        self.failUnlessEqual(f(4, 4), 0)
        self.failUnlessEqual(f(5, 4), 3)

    def test_is_power_of_k(self):
        f = mathutil.is_power_of_k
        for i in range(1, 100):
            if i in (1, 2, 4, 8, 16, 32, 64):
                self.failUnless(f(i, 2), "but %d *is* a power of 2" % i)
            else:
                self.failIf(f(i, 2), "but %d is *not* a power of 2" % i)
        for i in range(1, 100):
            if i in (1, 3, 9, 27, 81):
                self.failUnless(f(i, 3), "but %d *is* a power of 3" % i)
            else:
                self.failIf(f(i, 3), "but %d is *not* a power of 3" % i)

    def test_next_power_of_k(self):
        f = mathutil.next_power_of_k
        # For each base k: (first n, last n, expected f(n, k)).  The ranges
        # are INCLUSIVE so that exact powers (8, 16, ..., 81) are checked as
        # well as the values just below them; the earlier half-open ranges
        # skipped those boundary values.
        table = [
            (2, [(0, 1, 1), (2, 2, 2), (3, 4, 4), (5, 8, 8), (9, 16, 16),
                 (17, 32, 32), (33, 64, 64), (65, 99, 128)]),
            (3, [(0, 1, 1), (2, 3, 3), (4, 9, 9), (10, 27, 27),
                 (28, 81, 81), (82, 199, 243)]),
            ]
        for (k, spans) in table:
            for (lo, hi, expected) in spans:
                for n in range(lo, hi + 1):
                    self.failUnlessEqual(f(n, k), expected,
                                         "next_power_of_k(%d, %d)" % (n, k))

    def test_ave(self):
        f = mathutil.ave
        self.failUnlessEqual(f([1,2,3]), 2)
        self.failUnlessEqual(f([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(f([0.0, 1.0, 1.0]), .666666666666)
165
166
167 class Asserts(unittest.TestCase):
168     def should_assert(self, func, *args, **kwargs):
169         try:
170             func(*args, **kwargs)
171         except AssertionError, e:
172             return str(e)
173         except Exception, e:
174             self.fail("assert failed with non-AssertionError: %s" % e)
175         self.fail("assert was not caught")
176
177     def should_not_assert(self, func, *args, **kwargs):
178         if "re" in kwargs:
179             regexp = kwargs["re"]
180             del kwargs["re"]
181         try:
182             func(*args, **kwargs)
183         except AssertionError, e:
184             self.fail("assertion fired when it should not have: %s" % e)
185         except Exception, e:
186             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
187         return # we're happy
188
189
190     def test_assert(self):
191         f = assertutil._assert
192         self.should_assert(f)
193         self.should_assert(f, False)
194         self.should_not_assert(f, True)
195
196         m = self.should_assert(f, False, "message")
197         self.failUnlessEqual(m, "'message' <type 'str'>", m)
198         m = self.should_assert(f, False, "message1", othermsg=12)
199         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
200         m = self.should_assert(f, False, othermsg="message2")
201         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
202
203     def test_precondition(self):
204         f = assertutil.precondition
205         self.should_assert(f)
206         self.should_assert(f, False)
207         self.should_not_assert(f, True)
208
209         m = self.should_assert(f, False, "message")
210         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
211         m = self.should_assert(f, False, "message1", othermsg=12)
212         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
213         m = self.should_assert(f, False, othermsg="message2")
214         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
215
216     def test_postcondition(self):
217         f = assertutil.postcondition
218         self.should_assert(f)
219         self.should_assert(f, False)
220         self.should_not_assert(f, True)
221
222         m = self.should_assert(f, False, "message")
223         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
224         m = self.should_assert(f, False, "message1", othermsg=12)
225         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
226         m = self.should_assert(f, False, othermsg="message2")
227         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
228
229 class FileUtil(unittest.TestCase):
230     def mkdir(self, basedir, path, mode=0777):
231         fn = os.path.join(basedir, path)
232         fileutil.make_dirs(fn, mode)
233
234     def touch(self, basedir, path, mode=None, data="touch\n"):
235         fn = os.path.join(basedir, path)
236         f = open(fn, "w")
237         f.write(data)
238         f.close()
239         if mode is not None:
240             os.chmod(fn, mode)
241
242     def test_rm_dir(self):
243         basedir = "util/FileUtil/test_rm_dir"
244         fileutil.make_dirs(basedir)
245         # create it again to test idempotency
246         fileutil.make_dirs(basedir)
247         d = os.path.join(basedir, "doomed")
248         self.mkdir(d, "a/b")
249         self.touch(d, "a/b/1.txt")
250         self.touch(d, "a/b/2.txt", 0444)
251         self.touch(d, "a/b/3.txt", 0)
252         self.mkdir(d, "a/c")
253         self.touch(d, "a/c/1.txt")
254         self.touch(d, "a/c/2.txt", 0444)
255         self.touch(d, "a/c/3.txt", 0)
256         os.chmod(os.path.join(d, "a/c"), 0444)
257         self.mkdir(d, "a/d")
258         self.touch(d, "a/d/1.txt")
259         self.touch(d, "a/d/2.txt", 0444)
260         self.touch(d, "a/d/3.txt", 0)
261         os.chmod(os.path.join(d, "a/d"), 0)
262
263         fileutil.rm_dir(d)
264         self.failIf(os.path.exists(d))
265         # remove it again to test idempotency
266         fileutil.rm_dir(d)
267
268     def test_remove_if_possible(self):
269         basedir = "util/FileUtil/test_remove_if_possible"
270         fileutil.make_dirs(basedir)
271         self.touch(basedir, "here")
272         fn = os.path.join(basedir, "here")
273         fileutil.remove_if_possible(fn)
274         self.failIf(os.path.exists(fn))
275         fileutil.remove_if_possible(fn) # should be idempotent
276         fileutil.rm_dir(basedir)
277         fileutil.remove_if_possible(fn) # should survive errors
278
279     def test_open_or_create(self):
280         basedir = "util/FileUtil/test_open_or_create"
281         fileutil.make_dirs(basedir)
282         fn = os.path.join(basedir, "here")
283         f = fileutil.open_or_create(fn)
284         f.write("stuff.")
285         f.close()
286         f = fileutil.open_or_create(fn)
287         f.seek(0, 2)
288         f.write("more.")
289         f.close()
290         f = open(fn, "r")
291         data = f.read()
292         f.close()
293         self.failUnlessEqual(data, "stuff.more.")
294
295     def test_NamedTemporaryDirectory(self):
296         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
297         fileutil.make_dirs(basedir)
298         td = fileutil.NamedTemporaryDirectory(dir=basedir)
299         name = td.name
300         self.failUnless(basedir in name)
301         self.failUnless(basedir in repr(td))
302         self.failUnless(os.path.isdir(name))
303         del td
304         # it is conceivable that we need to force gc here, but I'm not sure
305         self.failIf(os.path.isdir(name))
306
307     def test_rename(self):
308         basedir = "util/FileUtil/test_rename"
309         fileutil.make_dirs(basedir)
310         self.touch(basedir, "here")
311         fn = os.path.join(basedir, "here")
312         fn2 = os.path.join(basedir, "there")
313         fileutil.rename(fn, fn2)
314         self.failIf(os.path.exists(fn))
315         self.failUnless(os.path.exists(fn2))
316
317     def test_du(self):
318         basedir = "util/FileUtil/test_du"
319         fileutil.make_dirs(basedir)
320         d = os.path.join(basedir, "space-consuming")
321         self.mkdir(d, "a/b")
322         self.touch(d, "a/b/1.txt", data="a"*10)
323         self.touch(d, "a/b/2.txt", data="b"*11)
324         self.mkdir(d, "a/c")
325         self.touch(d, "a/c/1.txt", data="c"*12)
326         self.touch(d, "a/c/2.txt", data="d"*13)
327
328         used = fileutil.du(basedir)
329         self.failUnlessEqual(10+11+12+13, used)
330
class PollMixinTests(unittest.TestCase):
    """Tests for pollmixin.PollMixin.poll()."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # a check that is already satisfied should let the Deferred fire
        return self.pm.poll(check_f=lambda: True, pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # the first check fails and the second succeeds
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next, pollinterval=0.1)

    def test_timeout(self):
        # a check that never succeeds must end in TimeoutError
        polling = self.pm.poll(check_f=lambda: False,
                               pollinterval=0.01,
                               timeout=1)
        def _unexpected_success(res):
            self.fail("poll should have failed, not returned %s" % (res,))
        def _expected_timeout(f):
            f.trap(pollmixin.TimeoutError)
            return None # success
        polling.addCallbacks(_unexpected_success, _expected_timeout)
        return polling
357
class DeferredUtilTests(unittest.TestCase):
    """Tests for deferredutil.DeferredListShouldSucceed."""

    def test_success(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        first.callback(1)
        second.callback(2)
        # all inputs succeeded: a single callback carrying the list of results
        self.failUnlessEqual(successes, [[1,2]])
        self.failUnlessEqual(failures, [])

    def test_failure(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        # Terminal errbacks on the inputs; presumably these stop the
        # RuntimeError from being reported as an unhandled Deferred error.
        first.addErrback(lambda _ignore: None)
        second.addErrback(lambda _ignore: None)
        first.callback(1)
        second.errback(RuntimeError())
        # one input failed: exactly one errback, carrying that Failure
        self.failUnlessEqual(successes, [])
        self.failUnlessEqual(len(failures), 1)
        f = failures[0]
        self.failUnless(isinstance(f, failure.Failure))
        self.failUnless(f.check(RuntimeError))
385
class HashUtilTests(unittest.TestCase):
    """Tests for allmydata.util.hashutil.

    Mostly checks that each one-shot *_hash function agrees with its
    incremental *_hasher counterpart.
    """

    def test_random_key(self):
        key = hashutil.random_key()
        self.failUnlessEqual(len(key), hashutil.KEYLEN)

    def test_sha256d(self):
        one_shot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        first_digest = hasher.digest()
        # asking for the digest a second time must give the same answer
        second_digest = hasher.digest()
        self.failUnlessEqual(one_shot, first_digest)
        self.failUnlessEqual(first_digest, second_digest)

    def test_sha256d_truncated(self):
        # the optional length argument truncates the output to 16 bytes
        one_shot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(one_shot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(one_shot, incremental)

    def test_chk(self):
        one_shot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        incremental = hasher.digest()
        self.failUnlessEqual(one_shot, incremental)

    def test_hashers(self):
        # For each PREFIX, PREFIX_hash("foo") must equal the digest of
        # PREFIX_hasher() after it has been fed "foo".
        for prefix in ("block", "uri_extension", "plaintext", "crypttext",
                       "crypttext_segment", "plaintext_segment"):
            one_shot = getattr(hashutil, prefix + "_hash")("foo")
            hasher = getattr(hashutil, prefix + "_hasher")()
            hasher.update("foo")
            self.failUnlessEqual(one_shot, hasher.digest())
447
class Limiter(unittest.TestCase):
    """Tests for limiter.ConcurrencyLimiter.

    The checks below assume the limiter's default concurrency is at most 10
    (TODO confirm against limiter.ConcurrencyLimiter's default).
    """

    def job(self, i, foo):
        """Fake job: record the call and stay "running" for one second."""
        self.calls.append( (i, foo) )
        self.simultaneous += 1
        if self.simultaneous > self.peak_simultaneous:
            self.peak_simultaneous = self.simultaneous
        finished = defer.Deferred()
        def _finish():
            self.simultaneous -= 1
            finished.callback("done %d" % i)
        reactor.callLater(1.0, _finish)
        return finished

    def bad_job(self, i, foo):
        """Fake job that fails synchronously."""
        raise RuntimeError("bad_job %d" % i)

    def _reset_counters(self):
        # state shared with job(): calls made, running now, most running at once
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0

    def _check_calls(self):
        # the limit was respected, and jobs 0..19 each ran once with foo=str(i)
        self.failUnless(self.peak_simultaneous <= 10)
        self.failUnlessEqual(len(self.calls), 20)
        for i in range(20):
            self.failUnless( (i, str(i)) in self.calls)

    def test_limiter(self):
        self._reset_counters()
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        d = defer.DeferredList(pending, fireOnOneErrback=True)
        def _verify(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls()
        d.addCallback(_verify)
        return d

    def test_errors(self):
        self._reset_counters()
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        d_bad = cl.add(self.bad_job, 21, "21")
        d = defer.DeferredList(pending, fireOnOneErrback=True)
        def _most_done(res):
            # every good job succeeded, despite the failing job in the queue
            for (success, result) in res:
                self.failUnlessEqual(success, True)
            results = sorted([result for (success, result) in res])
            expected_results = sorted(["done %d" % i for i in range(20)])
            self.failUnlessEqual(results, expected_results)
            self._check_calls()
            # ...and the failing job's Deferred carries its RuntimeError
            def _good(res):
                self.fail("should have failed, not got %s" % (res,))
            def _err(f):
                f.trap(RuntimeError)
                self.failUnless("bad_job 21" in str(f))
            d_bad.addCallbacks(_good, _err)
            return d_bad
        d.addCallback(_most_done)
        def _all_done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls()
        d.addCallback(_all_done)
        return d
520
class TimeFormat(unittest.TestCase):
    """Tests for allmydata.util.time_format."""

    def test_epoch(self):
        # Parsing: date and time may be separated by "T", "_" or a space.
        for sep in ("T", "_", " "):
            s = time_format.iso_utc_time_to_localseconds("1970-01-01%s00:00:01" % sep)
            self.failUnlessEqual(s, 1.0)

        # Formatting: "_" is the default separator; sep= overrides it.
        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")
        # iso_utc(t=...) also accepts a clock function in place of a timestamp
        def my_time():
            return 1.0
        self.failUnlessEqual(time_format.iso_utc(t=my_time),
                             "1970-01-01_00:00:01")
        e = self.failUnlessRaises(ValueError,
                                  time_format.iso_utc_time_to_localseconds,
                                  "invalid timestring")
        self.failUnless("not a complete ISO8601 timestamp" in str(e))
        # fractional seconds survive parsing
        s = time_format.iso_utc_time_to_localseconds("1970-01-01_00:00:01.500")
        self.failUnlessEqual(s, 1.5)