]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
0db1a89b52bc1b38367205c94584d3798f8f4f71
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): pass # keep the line number constant: HumanReadable.test_repr expects hr(foo) to report test_util.py:2
3
4 import os, time, sys
5 from StringIO import StringIO
6 from twisted.trial import unittest
7 from twisted.internet import defer, reactor
8 from twisted.python.failure import Failure
9 from twisted.python import log
10 from pycryptopp.hash.sha256 import SHA256 as _hash
11
12 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
13 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
14 from allmydata.util import limiter, time_format, pollmixin, cachedir
15 from allmydata.util import statistics, dictutil, pipeline
16 from allmydata.util import log as tahoe_log
17 from allmydata.util.spans import Spans, overlap, DataSpans
18
class Base32(unittest.TestCase):
    """Tests for the allmydata.util.base32 encode/decode helpers."""

    def test_b2a_matches_Pythons(self):
        import base64
        data = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        # base64.b32encode pads with '=' and emits upper case; base32.b2a
        # must match the unpadded, lower-cased form of that output.
        expected = base64.b32encode(data).rstrip('=').lower()
        self.failUnlessEqual(base32.b2a(data), expected)

    def test_b2a(self):
        encoded = base32.b2a("\x12\x34")
        self.failUnlessEqual(encoded, "ci2a")

    def test_b2a_or_none(self):
        # None passes straight through; anything else is encoded like b2a.
        for (arg, expected) in [(None, None), ("\x12\x34", "ci2a")]:
            self.failUnlessEqual(base32.b2a_or_none(arg), expected)

    def test_a2b(self):
        self.failUnlessEqual(base32.a2b("ci2a"), "\x12\x34")
        # malformed input must trip an assertion rather than decode
        self.failUnlessRaises(AssertionError, base32.a2b, "b0gus")
36
class IDLib(unittest.TestCase):
    """Tests for allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        all_zero_nodeid = "\x00" * 20
        expected = "a" * 32
        self.failUnlessEqual(idlib.nodeid_b2a(all_zero_nodeid), expected)
40
class NoArgumentException(Exception):
    """An exception whose constructor accepts no arguments at all.

    HumanReadable.test_repr raises it to check how hr() renders an
    exception instance that carries no args.
    """
    def __init__(self):
        """Accept no arguments; Exception.__init__ is deliberately not called."""
44
45 class HumanReadable(unittest.TestCase):
46     def test_repr(self):
47         hr = humanreadable.hr
48         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
49         self.failUnlessEqual(hr(self.test_repr),
50                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
51         self.failUnlessEqual(hr(1L), "1")
52         self.failUnlessEqual(hr(10**40),
53                              "100000000000000000...000000000000000000")
54         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
55         self.failUnlessEqual(hr([1,2]), "[1, 2]")
56         self.failUnlessEqual(hr({1:2}), "{1:2}")
57         try:
58             raise ValueError
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<ValueError: ()>" # python-2.4
62                 or hr(e) == "ValueError()") # python-2.5
63         try:
64             raise ValueError("oops")
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<ValueError: 'oops'>" # python-2.4
68                 or hr(e) == "ValueError('oops',)") # python-2.5
69         try:
70             raise NoArgumentException
71         except Exception, e:
72             self.failUnless(
73                 hr(e) == "<NoArgumentException>" # python-2.4
74                 or hr(e) == "NoArgumentException()") # python-2.5
75
76
class MyList(list):
    """A trivial list subclass: list-like, but not exactly of type list."""
79
class Math(unittest.TestCase):
    """Tests for the integer and float helpers in allmydata.util.mathutil."""

    def test_div_ceil(self):
        # div_ceil(n, d) is the quotient n/d rounded up
        f = mathutil.div_ceil
        self.failUnlessEqual(f(0, 1), 0)
        self.failUnlessEqual(f(0, 2), 0)
        self.failUnlessEqual(f(0, 3), 0)
        self.failUnlessEqual(f(1, 3), 1)
        self.failUnlessEqual(f(2, 3), 1)
        self.failUnlessEqual(f(3, 3), 1)
        self.failUnlessEqual(f(4, 3), 2)
        self.failUnlessEqual(f(5, 3), 2)
        self.failUnlessEqual(f(6, 3), 2)
        self.failUnlessEqual(f(7, 3), 3)

    def test_next_multiple(self):
        # next_multiple(n, k) is the smallest multiple of k that is >= n
        f = mathutil.next_multiple
        self.failUnlessEqual(f(5, 1), 5)
        self.failUnlessEqual(f(5, 2), 6)
        self.failUnlessEqual(f(5, 3), 6)
        self.failUnlessEqual(f(5, 4), 8)
        self.failUnlessEqual(f(5, 5), 5)
        self.failUnlessEqual(f(5, 6), 6)
        self.failUnlessEqual(f(32, 1), 32)
        self.failUnlessEqual(f(32, 2), 32)
        self.failUnlessEqual(f(32, 3), 33)
        self.failUnlessEqual(f(32, 4), 32)
        self.failUnlessEqual(f(32, 5), 35)
        self.failUnlessEqual(f(32, 6), 36)
        self.failUnlessEqual(f(32, 7), 35)
        self.failUnlessEqual(f(32, 8), 32)
        self.failUnlessEqual(f(32, 9), 36)
        self.failUnlessEqual(f(32, 10), 40)
        self.failUnlessEqual(f(32, 11), 33)
        self.failUnlessEqual(f(32, 12), 36)
        self.failUnlessEqual(f(32, 13), 39)
        self.failUnlessEqual(f(32, 14), 42)
        self.failUnlessEqual(f(32, 15), 45)
        self.failUnlessEqual(f(32, 16), 32)
        self.failUnlessEqual(f(32, 17), 34)
        self.failUnlessEqual(f(32, 18), 36)
        self.failUnlessEqual(f(32, 589), 589)

    def test_pad_size(self):
        # pad_size(n, k) is how much must be added to n to reach a multiple of k
        f = mathutil.pad_size
        self.failUnlessEqual(f(0, 4), 0)
        self.failUnlessEqual(f(1, 4), 3)
        self.failUnlessEqual(f(2, 4), 2)
        self.failUnlessEqual(f(3, 4), 1)
        self.failUnlessEqual(f(4, 4), 0)
        self.failUnlessEqual(f(5, 4), 3)

    def test_is_power_of_k(self):
        f = mathutil.is_power_of_k
        for i in range(1, 100):
            if i in (1, 2, 4, 8, 16, 32, 64):
                self.failUnless(f(i, 2), "but %d *is* a power of 2" % i)
            else:
                self.failIf(f(i, 2), "but %d is *not* a power of 2" % i)
        for i in range(1, 100):
            if i in (1, 3, 9, 27, 81):
                self.failUnless(f(i, 3), "but %d *is* a power of 3" % i)
            else:
                self.failIf(f(i, 3), "but %d is *not* a power of 3" % i)

    def test_next_power_of_k(self):
        f = mathutil.next_power_of_k
        self.failUnlessEqual(f(0,2), 1)
        self.failUnlessEqual(f(1,2), 1)
        self.failUnlessEqual(f(2,2), 2)
        self.failUnlessEqual(f(3,2), 4)
        self.failUnlessEqual(f(4,2), 4)
        # (first, last, expected): every i in first..last *inclusive* must
        # map to expected. Each range ends on the power itself, so the exact
        # powers (the boundary where an off-by-one would show) are covered.
        for (first, last, expected) in [(5, 8, 8), (9, 16, 16),
                                        (17, 32, 32), (33, 64, 64),
                                        (65, 99, 128)]:
            for i in range(first, last+1):
                self.failUnlessEqual(f(i,2), expected,
                                     "next_power_of_k(%d, 2)" % i)

        self.failUnlessEqual(f(0,3), 1)
        self.failUnlessEqual(f(1,3), 1)
        self.failUnlessEqual(f(2,3), 3)
        self.failUnlessEqual(f(3,3), 3)
        for (first, last, expected) in [(4, 9, 9), (10, 27, 27),
                                        (28, 81, 81), (82, 199, 243)]:
            for i in range(first, last+1):
                self.failUnlessEqual(f(i,3), expected,
                                     "next_power_of_k(%d, 3)" % i)

    def test_ave(self):
        f = mathutil.ave
        self.failUnlessEqual(f([1,2,3]), 2)
        self.failUnlessEqual(f([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(f([0.0, 1.0, 1.0]), .666666666666)

    def test_round_sigfigs(self):
        f = mathutil.round_sigfigs
        # 7.3330000000000002 is the full-precision repr of the double
        # nearest to 7.333, i.e. 22/3 rounded to 4 significant figures.
        self.failUnlessEqual(f(22.0/3, 4), 7.3330000000000002)
172     def test_round_sigfigs(self):
173         f = mathutil.round_sigfigs
174         self.failUnlessEqual(f(22.0/3, 4), 7.3330000000000002)
175
class Statistics(unittest.TestCase):
    """Tests for the probability helpers in allmydata.util.statistics."""

    def should_assert(self, msg, func, *args, **kwargs):
        """Fail the test with msg unless func(*args, **kwargs) raises
        AssertionError."""
        try:
            func(*args, **kwargs)
        except AssertionError:
            return
        # This must stay outside the try: self.fail() raises
        # self.failureException, which is AssertionError or a subclass of it
        # (trial's FailTest). Raised inside the try, it would be swallowed by
        # the except clause above and this helper could never fail.
        self.fail(msg)

    def failUnlessListEqual(self, a, b, msg = None):
        """Assert a and b have the same length and pairwise-equal elements."""
        self.failUnlessEqual(len(a), len(b), msg)
        for (x, y) in zip(a, b):
            self.failUnlessEqual(x, y, msg)

    def failUnlessListAlmostEqual(self, a, b, places = 7, msg = None):
        """Assert a and b have the same length and elements that agree to
        `places` decimal places."""
        self.failUnlessEqual(len(a), len(b), msg)
        for (x, y) in zip(a, b):
            self.failUnlessAlmostEqual(x, y, places, msg)

    def test_binomial_coeff(self):
        f = statistics.binomial_coeff
        self.failUnlessEqual(f(20, 0), 1)
        self.failUnlessEqual(f(20, 1), 20)
        self.failUnlessEqual(f(20, 2), 190)
        self.failUnlessEqual(f(20, 8), f(20, 12))
        self.should_assert("Should assert if n < k", f, 2, 3)

    def test_binomial_distribution_pmf(self):
        f = statistics.binomial_distribution_pmf

        pmf_comp = f(2, .1)
        pmf_stat = [0.81, 0.18, 0.01]
        self.failUnlessListAlmostEqual(pmf_comp, pmf_stat)

        # Summing across a PMF should give the total probability 1
        self.failUnlessAlmostEqual(sum(pmf_comp), 1)
        self.should_assert("Should assert if not 0<=p<=1", f, 1, -1)
        self.should_assert("Should assert if n < 1", f, 0, .1)

        out = StringIO()
        statistics.print_pmf(pmf_comp, out=out)
        lines = out.getvalue().splitlines()
        self.failUnlessEqual(lines[0], "i=0: 0.81")
        self.failUnlessEqual(lines[1], "i=1: 0.18")
        self.failUnlessEqual(lines[2], "i=2: 0.01")

    def test_survival_pmf(self):
        f = statistics.survival_pmf
        # Cross-check binomial-distribution method against convolution
        # method.
        p_list = [.9999] * 100 + [.99] * 50 + [.8] * 20
        pmf1 = statistics.survival_pmf_via_conv(p_list)
        pmf2 = statistics.survival_pmf_via_bd(p_list)
        self.failUnlessListAlmostEqual(pmf1, pmf2)
        self.failUnless(statistics.valid_pmf(pmf1))
        self.should_assert("Should assert if p_i > 1", f, [1.1])
        self.should_assert("Should assert if p_i < 0", f, [-.1])

    def test_repair_count_pmf(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        repair_pmf = statistics.repair_count_pmf(survival_pmf, 3)
        # repair_pmf[0] == sum(survival_pmf[0,1,2,5])
        # repair_pmf[1] == survival_pmf[4]
        # repair_pmf[2] == survival_pmf[3]
        self.failUnlessListAlmostEqual(repair_pmf,
                                       [0.00001 + 0.00045 + 0.0081 + 0.59049,
                                        .32805,
                                        .0729,
                                        0, 0, 0])

    def test_repair_cost(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        bwcost = statistics.bandwidth_cost_function
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 558.90)
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=8.0)
        self.failUnlessAlmostEqual(cost, 1664.55)

        # I haven't manually checked the math beyond here -warner
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 65292.056074766246)
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0.05,
                                              ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 9133.6097158191551)

    def test_convolve(self):
        f = statistics.convolve
        v1 = [ 1, 2, 3 ]
        v2 = [ 4, 5, 6 ]
        v3 = [ 7, 8 ]
        v1v2result = [ 4, 13, 28, 27, 18 ]
        # Convolution is commutative
        r1 = f(v1, v2)
        r2 = f(v2, v1)
        self.failUnlessListEqual(r1, r2, "Convolution should be commutative")
        self.failUnlessListEqual(r1, v1v2result, "Didn't match known result")
        # Convolution is associative
        r1 = f(f(v1, v2), v3)
        r2 = f(v1, f(v2, v3))
        self.failUnlessListEqual(r1, r2, "Convolution should be associative")
        # Convolution is distributive
        r1 = f(v3, [ a + b for a, b in zip(v1, v2) ])
        tmp1 = f(v3, v1)
        tmp2 = f(v3, v2)
        r2 = [ a + b for a, b in zip(tmp1, tmp2) ]
        self.failUnlessListEqual(r1, r2, "Convolution should be distributive")
        # Convolution is scalar multiplication associative
        tmp1 = f(v1, v2)
        r1 = [ a * 4 for a in tmp1 ]
        tmp2 = [ a * 4 for a in v1 ]
        r2 = f(tmp2, v2)
        self.failUnlessListEqual(r1, r2, "Convolution should be scalar multiplication associative")

    def test_find_k(self):
        f = statistics.find_k
        g = statistics.pr_file_loss
        plist = [.9] * 10 + [.8] * 10 # N=20
        t = .0001
        k = f(plist, t)
        self.failUnlessEqual(k, 10)
        self.failUnless(g(plist, k) < t)

    def test_pr_file_loss(self):
        f = statistics.pr_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, 3), .0546875)

    def test_pr_backup_file_loss(self):
        f = statistics.pr_backup_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, .5, 3), .02734375)
309         f = statistics.pr_backup_file_loss
310         plist = [.5] * 10
311         self.failUnlessEqual(f(plist, .5, 3), .02734375)
312
313
314 class Asserts(unittest.TestCase):
315     def should_assert(self, func, *args, **kwargs):
316         try:
317             func(*args, **kwargs)
318         except AssertionError, e:
319             return str(e)
320         except Exception, e:
321             self.fail("assert failed with non-AssertionError: %s" % e)
322         self.fail("assert was not caught")
323
324     def should_not_assert(self, func, *args, **kwargs):
325         try:
326             func(*args, **kwargs)
327         except AssertionError, e:
328             self.fail("assertion fired when it should not have: %s" % e)
329         except Exception, e:
330             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
331         return # we're happy
332
333
334     def test_assert(self):
335         f = assertutil._assert
336         self.should_assert(f)
337         self.should_assert(f, False)
338         self.should_not_assert(f, True)
339
340         m = self.should_assert(f, False, "message")
341         self.failUnlessEqual(m, "'message' <type 'str'>", m)
342         m = self.should_assert(f, False, "message1", othermsg=12)
343         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
344         m = self.should_assert(f, False, othermsg="message2")
345         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
346
347     def test_precondition(self):
348         f = assertutil.precondition
349         self.should_assert(f)
350         self.should_assert(f, False)
351         self.should_not_assert(f, True)
352
353         m = self.should_assert(f, False, "message")
354         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
355         m = self.should_assert(f, False, "message1", othermsg=12)
356         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
357         m = self.should_assert(f, False, othermsg="message2")
358         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
359
360     def test_postcondition(self):
361         f = assertutil.postcondition
362         self.should_assert(f)
363         self.should_assert(f, False)
364         self.should_not_assert(f, True)
365
366         m = self.should_assert(f, False, "message")
367         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
368         m = self.should_assert(f, False, "message1", othermsg=12)
369         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
370         m = self.should_assert(f, False, othermsg="message2")
371         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
372
373 class FileUtil(unittest.TestCase):
374     def mkdir(self, basedir, path, mode=0777):
375         fn = os.path.join(basedir, path)
376         fileutil.make_dirs(fn, mode)
377
378     def touch(self, basedir, path, mode=None, data="touch\n"):
379         fn = os.path.join(basedir, path)
380         f = open(fn, "w")
381         f.write(data)
382         f.close()
383         if mode is not None:
384             os.chmod(fn, mode)
385
386     def test_rm_dir(self):
387         basedir = "util/FileUtil/test_rm_dir"
388         fileutil.make_dirs(basedir)
389         # create it again to test idempotency
390         fileutil.make_dirs(basedir)
391         d = os.path.join(basedir, "doomed")
392         self.mkdir(d, "a/b")
393         self.touch(d, "a/b/1.txt")
394         self.touch(d, "a/b/2.txt", 0444)
395         self.touch(d, "a/b/3.txt", 0)
396         self.mkdir(d, "a/c")
397         self.touch(d, "a/c/1.txt")
398         self.touch(d, "a/c/2.txt", 0444)
399         self.touch(d, "a/c/3.txt", 0)
400         os.chmod(os.path.join(d, "a/c"), 0444)
401         self.mkdir(d, "a/d")
402         self.touch(d, "a/d/1.txt")
403         self.touch(d, "a/d/2.txt", 0444)
404         self.touch(d, "a/d/3.txt", 0)
405         os.chmod(os.path.join(d, "a/d"), 0)
406
407         fileutil.rm_dir(d)
408         self.failIf(os.path.exists(d))
409         # remove it again to test idempotency
410         fileutil.rm_dir(d)
411
412     def test_remove_if_possible(self):
413         basedir = "util/FileUtil/test_remove_if_possible"
414         fileutil.make_dirs(basedir)
415         self.touch(basedir, "here")
416         fn = os.path.join(basedir, "here")
417         fileutil.remove_if_possible(fn)
418         self.failIf(os.path.exists(fn))
419         fileutil.remove_if_possible(fn) # should be idempotent
420         fileutil.rm_dir(basedir)
421         fileutil.remove_if_possible(fn) # should survive errors
422
423     def test_write_atomically(self):
424         basedir = "util/FileUtil/test_write_atomically"
425         fileutil.make_dirs(basedir)
426         fn = os.path.join(basedir, "here")
427         fileutil.write_atomically(fn, "one")
428         self.failUnlessEqual(fileutil.read(fn), "one")
429         fileutil.write_atomically(fn, "two", mode="") # non-binary
430         self.failUnlessEqual(fileutil.read(fn), "two")
431
432     def test_open_or_create(self):
433         basedir = "util/FileUtil/test_open_or_create"
434         fileutil.make_dirs(basedir)
435         fn = os.path.join(basedir, "here")
436         f = fileutil.open_or_create(fn)
437         f.write("stuff.")
438         f.close()
439         f = fileutil.open_or_create(fn)
440         f.seek(0, os.SEEK_END)
441         f.write("more.")
442         f.close()
443         f = open(fn, "r")
444         data = f.read()
445         f.close()
446         self.failUnlessEqual(data, "stuff.more.")
447
448     def test_NamedTemporaryDirectory(self):
449         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
450         fileutil.make_dirs(basedir)
451         td = fileutil.NamedTemporaryDirectory(dir=basedir)
452         name = td.name
453         self.failUnless(basedir in name)
454         self.failUnless(basedir in repr(td))
455         self.failUnless(os.path.isdir(name))
456         del td
457         # it is conceivable that we need to force gc here, but I'm not sure
458         self.failIf(os.path.isdir(name))
459
460     def test_rename(self):
461         basedir = "util/FileUtil/test_rename"
462         fileutil.make_dirs(basedir)
463         self.touch(basedir, "here")
464         fn = os.path.join(basedir, "here")
465         fn2 = os.path.join(basedir, "there")
466         fileutil.rename(fn, fn2)
467         self.failIf(os.path.exists(fn))
468         self.failUnless(os.path.exists(fn2))
469
470     def test_du(self):
471         basedir = "util/FileUtil/test_du"
472         fileutil.make_dirs(basedir)
473         d = os.path.join(basedir, "space-consuming")
474         self.mkdir(d, "a/b")
475         self.touch(d, "a/b/1.txt", data="a"*10)
476         self.touch(d, "a/b/2.txt", data="b"*11)
477         self.mkdir(d, "a/c")
478         self.touch(d, "a/c/1.txt", data="c"*12)
479         self.touch(d, "a/c/2.txt", data="d"*13)
480
481         used = fileutil.du(basedir)
482         self.failUnlessEqual(10+11+12+13, used)
483
484     def test_abspath_expanduser_unicode(self):
485         self.failUnlessRaises(AssertionError, fileutil.abspath_expanduser_unicode, "bytestring")
486
487         saved_cwd = os.path.normpath(os.getcwdu())
488         abspath_cwd = fileutil.abspath_expanduser_unicode(u".")
489         self.failUnless(isinstance(saved_cwd, unicode), saved_cwd)
490         self.failUnless(isinstance(abspath_cwd, unicode), abspath_cwd)
491         self.failUnlessEqual(abspath_cwd, saved_cwd)
492
493         # adapted from <http://svn.python.org/view/python/branches/release26-maint/Lib/test/test_posixpath.py?view=markup&pathrev=78279#test_abspath>
494
495         self.failUnlessIn(u"foo", fileutil.abspath_expanduser_unicode(u"foo"))
496         self.failIfIn(u"~", fileutil.abspath_expanduser_unicode(u"~"))
497
498         cwds = ['cwd']
499         try:
500             cwds.append(u'\xe7w\xf0'.encode(sys.getfilesystemencoding()
501                                             or 'ascii'))
502         except UnicodeEncodeError:
503             pass # the cwd can't be encoded -- test with ascii cwd only
504
505         for cwd in cwds:
506             try:
507                 os.mkdir(cwd)
508                 os.chdir(cwd)
509                 for upath in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\', u'~'):
510                     uabspath = fileutil.abspath_expanduser_unicode(upath)
511                     self.failUnless(isinstance(uabspath, unicode), uabspath)
512             finally:
513                 os.chdir(saved_cwd)
514
515     def test_disk_stats(self):
516         avail = fileutil.get_available_space('.', 2**14)
517         if avail == 0:
518             raise unittest.SkipTest("This test will spuriously fail there is no disk space left.")
519
520         disk = fileutil.get_disk_stats('.', 2**13)
521         self.failUnless(disk['total'] > 0, disk['total'])
522         self.failUnless(disk['used'] > 0, disk['used'])
523         self.failUnless(disk['free_for_root'] > 0, disk['free_for_root'])
524         self.failUnless(disk['free_for_nonroot'] > 0, disk['free_for_nonroot'])
525         self.failUnless(disk['avail'] > 0, disk['avail'])
526
527     def test_disk_stats_avail_nonnegative(self):
528         # This test will spuriously fail if you have more than 2^128
529         # bytes of available space on your filesystem.
530         disk = fileutil.get_disk_stats('.', 2**128)
531         self.failUnlessEqual(disk['avail'], 0)
532
class PollMixinTests(unittest.TestCase):
    """Tests for pollmixin.PollMixin.poll(). Each test returns the Deferred
    from poll() so that trial waits for it."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # the check succeeds on the first call
        return self.pm.poll(check_f=lambda: True, pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # the check fails once, then succeeds on the next poll
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next, pollinterval=0.1)

    def test_timeout(self):
        # a check that never succeeds must end in pollmixin.TimeoutError
        def _unexpected_success(res):
            self.fail("poll should have failed, not returned %s" % (res,))
        def _expect_timeout(failure):
            failure.trap(pollmixin.TimeoutError)
            return None # success
        polling = self.pm.poll(check_f=lambda: False,
                               pollinterval=0.01,
                               timeout=1)
        polling.addCallbacks(_unexpected_success, _expect_timeout)
        return polling
556             return None # success
557         d.addCallbacks(_suc, _err)
558         return d
559
class DeferredUtilTests(unittest.TestCase):
    """Tests for allmydata.util.deferredutil."""

    def test_gather_results(self):
        first, second = defer.Deferred(), defer.Deferred()
        gathered = deferredutil.gatherResults([first, second])
        # an error in any input must errback the gathered Deferred
        first.errback(ValueError("BAD"))
        def _unexpected_success(res):
            self.fail("Should have errbacked, not resulted in %s" % (res,))
        def _expect_valueerror(failure):
            failure.trap(ValueError)
        gathered.addCallbacks(_unexpected_success, _expect_valueerror)
        return gathered

    def test_success(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        first.callback(1)
        second.callback(2)
        # the combined Deferred fires once, with the results in input order
        self.failUnlessEqual(successes, [[1,2]])
        self.failUnlessEqual(failures, [])

    def test_failure(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        # Absorb whatever remains on each input after DLSS has seen it;
        # presumably this keeps the ValueError from being reported as an
        # unhandled error in Deferred -- confirm against deferredutil.
        first.addErrback(lambda _ignore: None)
        second.addErrback(lambda _ignore: None)
        first.callback(1)
        second.errback(ValueError())
        self.failUnlessEqual(successes, [])
        self.failUnlessEqual(len(failures), 1)
        failure = failures[0]
        self.failUnless(isinstance(failure, Failure))
        self.failUnless(failure.check(ValueError))
596         f = bad[0]
597         self.failUnless(isinstance(f, Failure))
598         self.failUnless(f.check(ValueError))
599
class HashUtilTests(unittest.TestCase):
    """Tests for allmydata.util.hashutil."""

    def test_random_key(self):
        key = hashutil.random_key()
        self.failUnlessEqual(len(key), hashutil.KEYLEN)

    def test_sha256d(self):
        # The one-shot tagged_hash must agree with the incremental hasher,
        # and calling digest() twice must return the same value.
        one_shot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        first_digest = hasher.digest()
        second_digest = hasher.digest()
        self.failUnlessEqual(one_shot, first_digest)
        self.failUnlessEqual(first_digest, second_digest)

    def test_sha256d_truncated(self):
        # both forms accept a truncation length and must honour it
        one_shot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(one_shot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(one_shot, incremental)

    def test_chk(self):
        one_shot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        self.failUnlessEqual(one_shot, hasher.digest())

    def test_hashers(self):
        # each one-shot *_hash must equal its incremental *_hasher on "foo"
        pairs = [
            (hashutil.block_hash, hashutil.block_hasher),
            (hashutil.uri_extension_hash, hashutil.uri_extension_hasher),
            (hashutil.plaintext_hash, hashutil.plaintext_hasher),
            (hashutil.crypttext_hash, hashutil.crypttext_hasher),
            (hashutil.crypttext_segment_hash, hashutil.crypttext_segment_hasher),
            (hashutil.plaintext_segment_hash, hashutil.plaintext_segment_hasher),
            ]
        for (hash_fn, make_hasher) in pairs:
            one_shot = hash_fn("foo")
            hasher = make_hasher()
            hasher.update("foo")
            self.failUnlessEqual(one_shot, hasher.digest())

    def test_constant_time_compare(self):
        ctc = hashutil.constant_time_compare
        self.failUnless(ctc("a", "a"))
        self.failUnless(ctc("ab", "ab"))
        self.failIf(ctc("a", "b"))
        self.failIf(ctc("a", "aa"))

    def _testknown(self, hashf, expected_a, *args):
        """Check that the base32 form of hashf(*args) equals expected_a."""
        self.failUnlessEqual(base32.b2a(hashf(*args)), expected_a)

    def test_known_answers(self):
        # assert backwards compatibility: (hash function, expected base32
        # digest, arguments)
        known = [
            (hashutil.storage_index_hash, "qb5igbhcc5esa6lwqorsy7e6am", ("",)),
            (hashutil.block_hash, "msjr5bh4evuh7fa3zw7uovixfbvlnstr5b65mrerwfnvjxig2jvq", ("",)),
            (hashutil.uri_extension_hash, "wthsu45q7zewac2mnivoaa4ulh5xvbzdmsbuyztq2a5fzxdrnkka", ("",)),
            (hashutil.plaintext_hash, "5lz5hwz3qj3af7n6e3arblw7xzutvnd3p3fjsngqjcb7utf3x3da", ("",)),
            (hashutil.crypttext_hash, "itdj6e4njtkoiavlrmxkvpreosscssklunhwtvxn6ggho4rkqwga", ("",)),
            (hashutil.crypttext_segment_hash, "aovy5aa7jej6ym5ikgwyoi4pxawnoj3wtaludjz7e2nb5xijb7aa", ("",)),
            (hashutil.plaintext_segment_hash, "4fdgf6qruaisyukhqcmoth4t3li6bkolbxvjy4awwcpprdtva7za", ("",)),
            (hashutil.convergence_hash, "3mo6ni7xweplycin6nowynw2we", (3, 10, 100, "", "converge")),
            (hashutil.my_renewal_secret_hash, "ujhr5k5f7ypkp67jkpx6jl4p47pyta7hu5m527cpcgvkafsefm6q", ("",)),
            (hashutil.my_cancel_secret_hash, "rjwzmafe2duixvqy6h47f5wfrokdziry6zhx4smew4cj6iocsfaa", ("",)),
            (hashutil.file_renewal_secret_hash, "hzshk2kf33gzbd5n3a6eszkf6q6o6kixmnag25pniusyaulqjnia", ("", "si")),
            (hashutil.file_cancel_secret_hash, "bfciwvr6w7wcavsngxzxsxxaszj72dej54n4tu2idzp6b74g255q", ("", "si")),
            (hashutil.bucket_renewal_secret_hash, "e7imrzgzaoashsncacvy3oysdd2m5yvtooo4gmj4mjlopsazmvuq", ("", "\x00"*20)),
            (hashutil.bucket_cancel_secret_hash, "dvdujeyxeirj6uux6g7xcf4lvesk632aulwkzjar7srildvtqwma", ("", "\x00"*20)),
            (hashutil.hmac, "c54ypfi6pevb3nvo6ba42jtglpkry2kbdopqsi7dgrm4r7tw5sra", ("tag", "")),
            (hashutil.mutable_rwcap_key_hash, "6rvn2iqrghii5n4jbbwwqqsnqu", ("iv", "wk")),
            (hashutil.ssk_writekey_hash, "ykpgmdbpgbb6yqz5oluw2q26ye", ("",)),
            (hashutil.ssk_write_enabler_master_hash, "izbfbfkoait4dummruol3gy2bnixrrrslgye6ycmkuyujnenzpia", ("",)),
            (hashutil.ssk_write_enabler_hash, "fuu2dvx7g6gqu5x22vfhtyed7p4pd47y5hgxbqzgrlyvxoev62tq", ("wk", "\x00"*20)),
            (hashutil.ssk_pubkey_fingerprint_hash, "3opzw4hhm2sgncjx224qmt5ipqgagn7h5zivnfzqycvgqgmgz35q", ("",)),
            (hashutil.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", ("",)),
            (hashutil.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", ("iv", "rk")),
            (hashutil.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", ("",)),
            ]
        for (hashf, expected_a, args) in known:
            self._testknown(hashf, expected_a, *args)
695         self._testknown(hashutil.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", "")
696         self._testknown(hashutil.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", "iv", "rk")
697         self._testknown(hashutil.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", "")
698
699
class Abbreviate(unittest.TestCase):
    """Tests for the human-readable helpers in allmydata.util.abbreviate."""

    def test_time(self):
        """abbreviate_time renders a duration in seconds using a suitable unit."""
        MIN = 60
        HOUR = 60*MIN
        DAY = 24*HOUR
        MONTH = 30*DAY
        YEAR = 365*DAY
        cases = [(None, "unknown"),
                 (0, "0 seconds"),
                 (1, "1 second"),
                 (2, "2 seconds"),
                 (119, "119 seconds"),
                 (2*MIN, "2 minutes"),
                 (60*MIN, "60 minutes"),
                 (179*MIN, "179 minutes"),
                 (180*MIN, "3 hours"),
                 (4*HOUR, "4 hours"),
                 (2*DAY, "2 days"),
                 (2*MONTH, "2 months"),
                 (5*YEAR, "5 years"),
                 ]
        for (seconds, expected) in cases:
            self.failUnlessEqual(abbreviate.abbreviate_time(seconds), expected)

    def _check_space(self, cases, SI):
        """Assert abbreviate_space(size, SI=SI) == expected for each (size, expected)."""
        for (size, expected) in cases:
            self.failUnlessEqual(abbreviate.abbreviate_space(size, SI=SI),
                                 expected)

    def test_space(self):
        """abbreviate_space supports decimal (SI) and binary (base-1024) units."""
        self._check_space([(None, "unknown"),
                           (0, "0 B"),
                           (1, "1 B"),
                           (999, "999 B"),
                           (1000, "1000 B"),
                           (1023, "1023 B"),
                           (1024, "1.02 kB"),
                           (20*1000, "20.00 kB"),
                           (1024*1024, "1.05 MB"),
                           (1000*1000, "1.00 MB"),
                           (1000*1000*1000, "1.00 GB"),
                           (1000*1000*1000*1000, "1.00 TB"),
                           (1000*1000*1000*1000*1000, "1.00 PB"),
                           (1234567890123456, "1.23 PB"),
                           ], SI=True)

        self._check_space([(None, "unknown"),
                           (0, "0 B"),
                           (1, "1 B"),
                           (999, "999 B"),
                           (1000, "1000 B"),
                           (1023, "1023 B"),
                           (1024, "1.00 kiB"),
                           (20*1024, "20.00 kiB"),
                           (1000*1000, "976.56 kiB"),
                           (1024*1024, "1.00 MiB"),
                           (1024*1024*1024, "1.00 GiB"),
                           (1024*1024*1024*1024, "1.00 TiB"),
                           (1000*1000*1000*1000*1000, "909.49 TiB"),
                           (1024*1024*1024*1024*1024, "1.00 PiB"),
                           (1234567890123456, "1.10 PiB"),
                           ], SI=False)

        # abbreviate_space_both reports the SI and the binary form together
        self.failUnlessEqual(abbreviate.abbreviate_space_both(1234567),
                             "(1.23 MB, 1.18 MiB)")

    def test_parse_space(self):
        """parse_abbreviated_size accepts bare, SI and binary size strings."""
        parse = abbreviate.parse_abbreviated_size
        for (text, expected) in [("", None),
                                 (None, None),
                                 ("123", 123),
                                 ("123B", 123),
                                 ("2K", 2000),
                                 ("2kb", 2000),
                                 ("2KiB", 2048),
                                 ("10MB", 10*1000*1000),
                                 ("10MiB", 10*1024*1024),
                                 ("5G", 5*1000*1000*1000),
                                 ("4GiB", 4*1024*1024*1024),
                                 ]:
            self.failUnlessEqual(parse(text), expected)
        # an unknown unit is rejected and the message names the bad input
        err = self.failUnlessRaises(ValueError, parse, "12 cubits")
        self.failUnless("12 cubits" in str(err))
780
class Limiter(unittest.TestCase):
    """Tests for allmydata.util.limiter.ConcurrencyLimiter."""

    # This takes longer than 240 seconds on Francois's arm box.
    timeout = 480

    def _reset_counters(self):
        """Clear the bookkeeping that job() updates."""
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0

    def _check_calls(self):
        """Assert at most 10 jobs overlapped and each of the 20 jobs ran."""
        self.failUnless(self.peak_simultaneous <= 10)
        self.failUnlessEqual(len(self.calls), 20)
        for i in range(20):
            self.failUnless( (i, str(i)) in self.calls)

    def job(self, i, foo):
        """Record the call, track overlap, and fire "done <i>" one second later."""
        self.calls.append( (i, foo) )
        self.simultaneous += 1
        self.peak_simultaneous = max(self.simultaneous, self.peak_simultaneous)
        result = defer.Deferred()
        def _finish():
            self.simultaneous -= 1
            result.callback("done %d" % i)
        reactor.callLater(1.0, _finish)
        return result

    def bad_job(self, i, foo):
        """Fail synchronously with ValueError("bad_job <i>")."""
        raise ValueError("bad_job %d" % i)

    def test_limiter(self):
        self._reset_counters()
        lim = limiter.ConcurrencyLimiter()
        jobs = [lim.add(self.job, i, foo=str(i)) for i in range(20)]
        d = defer.DeferredList(jobs, fireOnOneErrback=True)
        def _done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls()
        d.addCallback(_done)
        return d

    def test_errors(self):
        self._reset_counters()
        lim = limiter.ConcurrencyLimiter()
        jobs = [lim.add(self.job, i, foo=str(i)) for i in range(20)]
        # one extra job that raises; its Deferred is checked once the
        # 20 good jobs have finished
        d2 = lim.add(self.bad_job, 21, "21")
        d = defer.DeferredList(jobs, fireOnOneErrback=True)
        def _most_done(res):
            for (success, result) in res:
                self.failUnlessEqual(success, True)
            results = sorted([result for (success, result) in res])
            wanted = sorted(["done %d" % i for i in range(20)])
            self.failUnlessEqual(results, wanted)
            self._check_calls()
            def _good(res):
                self.fail("should have failed, not got %s" % (res,))
            def _err(f):
                f.trap(ValueError)
                self.failUnless("bad_job 21" in str(f))
            d2.addCallbacks(_good, _err)
            return d2
        d.addCallback(_most_done)
        def _all_done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls()
        d.addCallback(_all_done)
        return d
855
class TimeFormat(unittest.TestCase):
    """Tests for allmydata.util.time_format parsing and formatting."""

    def test_epoch(self):
        return self._help_test_epoch()

    def _set_tz(self, tz):
        """Set $TZ to tz (or remove it when tz is None), then re-read it."""
        if tz is None:
            del os.environ['TZ']
        else:
            os.environ['TZ'] = tz
        # time.tzset() does not exist on every platform
        if hasattr(time, 'tzset'):
            time.tzset()

    def test_epoch_in_London(self):
        # Europe/London is an awkward zone: today its standard-time offset
        # from GMT is 0, but in 1970 Britain kept GMT+1 all year round, so
        # the zone's offset at the epoch differs from its offset now.
        # Running the epoch checks under this zone catches code that
        # assumes today's offset also held in 1970.
        # NOTE(review): an older comment here said
        # time_format.iso_utc_time_to_localseconds() broke in this zone;
        # confirm whether that still applies.
        saved_tz = os.environ.get('TZ')
        self._set_tz("Europe/London")
        try:
            return self._help_test_epoch()
        finally:
            self._set_tz(saved_tz)

    def _help_test_epoch(self):
        saved_tzname = time.tzname

        # all three accepted date/time separators parse the same instant
        for stamp in ("1970-01-01T00:00:01",
                      "1970-01-01_00:00:01",
                      "1970-01-01 00:00:01"):
            secs = time_format.iso_utc_time_to_seconds(stamp)
            self.failUnlessEqual(secs, 1.0)

        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")

        # round-trip the current time, compared at whole-second precision
        now = time.time()
        roundtrip = time_format.iso_utc_time_to_seconds(time_format.iso_utc(now))
        self.failUnlessEqual(int(roundtrip), int(now))

        # iso_utc() can also take a clock function via t=
        self.failUnlessEqual(time_format.iso_utc(t=lambda: 1.0),
                             "1970-01-01_00:00:01")
        err = self.failUnlessRaises(ValueError,
                                    time_format.iso_utc_time_to_seconds,
                                    "invalid timestring")
        self.failUnless("not a complete ISO8601 timestamp" in str(err))
        secs = time_format.iso_utc_time_to_seconds("1970-01-01_00:00:01.500")
        self.failUnlessEqual(secs, 1.5)

        # Look for daylight-savings-related errors.
        march = time_format.iso_utc_time_to_seconds("2009-03-20 21:49:02.226536")
        self.failUnlessEqual(march, 1237585742.226536)
        # none of the above may have changed the process's tzname
        self.failUnlessEqual(saved_tzname, time.tzname)

    def test_iso_utc(self):
        when = 1266760143.7841301
        self.failUnlessEqual(time_format.iso_utc_date(when), "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc_date(t=lambda: when),
                             "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc(when),
                             "2010-02-21_13:49:03.784130")
        self.failUnlessEqual(time_format.iso_utc(when, sep="-"),
                             "2010-02-21-13:49:03.784130")

    def test_parse_duration(self):
        parse = time_format.parse_duration
        DAY = 24*60*60
        for (text, expected) in [("1 day", DAY),
                                 ("2 days", 2*DAY),
                                 ("3 months", 3*31*DAY),
                                 ("4 mo", 4*31*DAY),
                                 ("5 years", 5*365*DAY),
                                 ]:
            self.failUnlessEqual(parse(text), expected)
        err = self.failUnlessRaises(ValueError, parse, "123")
        self.failUnlessIn("no unit (like day, month, or year) in '123'",
                          str(err))

    def test_parse_date(self):
        self.failUnlessEqual(time_format.parse_date("2010-02-21"), 1266710400)
943
class CacheDir(unittest.TestCase):
    """Tests for allmydata.util.cachedir.CacheDirectoryManager."""

    def test_basic(self):
        """check() removes a cache file only once it is unreferenced and old.

        The test drops references with ``del`` and moves the manager's
        ``old`` threshold to observe which backing files check() removes.
        """
        basedir = "test_util/CacheDir/test_basic"

        def _failIfExists(name):
            absfn = os.path.join(basedir, name)
            self.failIf(os.path.exists(absfn),
                        "%s exists but it shouldn't" % absfn)

        def _failUnlessExists(name):
            absfn = os.path.join(basedir, name)
            self.failUnless(os.path.exists(absfn),
                            "%s doesn't exist but it should" % absfn)

        def _expect(absent, present):
            # assert each name in `absent` is gone and each in `present` exists
            for name in absent:
                _failIfExists(name)
            for name in present:
                _failUnlessExists(name)

        def _write_hi(filename):
            # try/finally closes the handle even if write() raises, instead
            # of relying on the file object being garbage-collected
            f = open(filename, "wb")
            try:
                f.write("hi")
            finally:
                f.close()

        cdm = cachedir.CacheDirectoryManager(basedir)
        a = cdm.get_file("a")
        b = cdm.get_file("b")
        c = cdm.get_file("c")
        _write_hi(a.get_filename())
        _write_hi(b.get_filename())
        _write_hi(c.get_filename())

        _expect([], ["a", "b", "c"])

        cdm.check()
        _expect([], ["a", "b", "c"])

        del a
        # this file won't be deleted yet, because it isn't old enough
        cdm.check()
        _expect([], ["a", "b", "c"])

        # we change the definition of "old" to make everything old
        cdm.old = -10

        cdm.check()
        # only "a" has no remaining reference, so only it is removed
        _expect(["a"], ["b", "c"])

        cdm.old = 60*60

        del b

        # "b" is unreferenced now, but with cdm.old back at 60*60 it is
        # not old enough to be removed
        cdm.check()
        _expect(["a"], ["b", "c"])

        b2 = cdm.get_file("b")

        cdm.check()
        _expect(["a"], ["b", "c"])
        del b2
1007
# Shared counter that hands each EqButNotIs instance its own hash value.
ctr = [0]
class EqButNotIs:
    """A wrapper that compares equal to its wrapped value without being it.

    Every comparison operator compares ``self.x`` against the other operand,
    so ``EqButNotIs(3) == 3`` is true while ``EqButNotIs(3) is 3`` is false.
    Each instance gets a hash drawn from ``ctr`` that never equals
    ``hash(x)``. That deliberately breaks the usual "equal objects hash
    equal" rule, so a dict can hold this object and an equal plain value
    as two separate keys.
    """
    def __init__(self, x):
        self.x = x
        # If our hash matched hash(x), then (because __eq__ is also true) a
        # dict would treat this object and x as the same key, which defeats
        # the purpose of the class. Skip that one counter value.
        try:
            forbidden = hash(x)
        except TypeError:
            forbidden = None # unhashable x can never collide as a dict key
        while True:
            h = ctr[0]
            ctr[0] += 1
            if h != forbidden:
                break
        self.hash = h
    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.x,)
    def __hash__(self):
        return self.hash
    def __le__(self, other):
        return self.x <= other
    def __lt__(self, other):
        return self.x < other
    def __ge__(self, other):
        return self.x >= other
    def __gt__(self, other):
        return self.x > other
    def __ne__(self, other):
        return self.x != other
    def __eq__(self, other):
        return self.x == other
1030
class DictUtil(unittest.TestCase):
    """Tests for the dict helpers and dict-like classes in allmydata.util.dictutil."""

    def _help_test_empty_dict(self, klass):
        """Check that klass() and klass({}) are equal and both empty."""
        d1 = klass()
        d2 = klass({})

        self.failUnless(d1 == d2, "d1: %r, d2: %r" % (d1, d2,))
        self.failUnless(len(d1) == 0)
        self.failUnless(len(d2) == 0)

    def _help_test_nonempty_dict(self, klass):
        """Check that two klass instances built from equal dicts agree."""
        d1 = klass({'a': 1, 'b': "eggs", 3: "spam",})
        d2 = klass({'a': 1, 'b': "eggs", 3: "spam",})

        self.failUnless(d1 == d2)
        self.failUnless(len(d1) == 3, "%s, %s" % (len(d1), d1,))
        self.failUnless(len(d2) == 3)

    def _help_test_eq_but_notis(self, klass):
        """Exercise klass with keys/values that are equal but not identical.

        EqButNotIs objects compare equal to plain ints but carry their own
        per-instance hash, so klass must keep such a key separate from an
        equal int key and must replace values by key, not by value equality.
        """
        d = klass({'a': 3, 'b': EqButNotIs(3), 'c': 3})
        d.pop('b')

        d.clear()
        d['a'] = 3
        d['b'] = EqButNotIs(3)
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['b'] = EqButNotIs(3)
        d['a'] = 3
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['a'] = EqButNotIs(3)
        d['c'] = 3
        d['a'] = 3

        d.clear()
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

        d.clear()
        d[3] = 7
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

    def test_all(self):
        """Run every shared helper against each dict-like class.

        _help_test_empty_dict is now called here; before this change it was
        defined but never used, and the equal-but-not-identical test ran
        twice per class instead.
        """
        classes = (dictutil.UtilDict, dictutil.NumDict,
                   dictutil.ValueOrderedDict)
        for klass in classes:
            self._help_test_empty_dict(klass)
        for klass in classes:
            self._help_test_nonempty_dict(klass)
        for klass in classes:
            self._help_test_eq_but_notis(klass)

    def test_dict_of_sets(self):
        ds = dictutil.DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        self.failIf(2 in ds)

        ds.add(3, "f")
        ds2 = dictutil.DictOfSets()
        ds2.add(3, "f")
        ds2.add(3, "g")
        ds2.add(4, "h")
        ds.update(ds2)
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[3], set(["f", "g"]))
        self.failUnlessEqual(ds[4], set(["h"]))

    def test_move(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(1, d1, d2)
        self.failUnlessEqual(d1, {2: "b"})
        self.failUnlessEqual(d2, {1: "a", 2: "c", 3: "d"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(2, d1, d2)
        self.failUnlessEqual(d1, {1: "a"})
        self.failUnlessEqual(d2, {2: "b", 3: "d"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        self.failUnlessRaises(KeyError, dictutil.move, 5, d1, d2, strict=True)

    def test_subtract(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

    def test_utildict(self):
        d = dictutil.UtilDict({1: "a", 2: "b"})
        d.del_if_present(1)
        d.del_if_present(3)
        self.failUnlessEqual(d, {2: "b"})
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d = dictutil.UtilDict({1: "b", 2: "a"})
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [(2, "a"), (1, "b")])
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(repr(d), "{1: 'b', 2: 'a'}")
        self.failUnless(1 in d)

        d2 = dictutil.UtilDict({3: "c", 4: "d"})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d[1], "b")
        self.failUnlessEqual(sorted(list([k for k in d])), [1,2])

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.UtilDict))

        d4 = d.fromkeys([3,4], "e")
        self.failUnlessEqual(d4, {3: "e", 4: "e"})

        self.failUnlessEqual(d.get(1), "b")
        self.failUnlessEqual(d.get(3), None)
        self.failUnlessEqual(d.get(3, "default"), "default")
        self.failUnlessEqual(sorted(list(d.items())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.keys()), [1, 2])
        self.failUnlessEqual(sorted(d.values()), ["a", "b"])
        x = d.setdefault(1, "new")
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d[1], "b")
        x = d.setdefault(3, "new")
        self.failUnlessEqual(x, "new")
        self.failUnlessEqual(d[3], "new")
        del d[3]

        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        self.failUnlessRaises(KeyError, d.popitem)

    def test_numdict(self):
        d = dictutil.NumDict({"a": 1, "b": 2})

        d.add_num("a", 10, 5)
        d.add_num("c", 20, 5)
        d.add_num("d", 30)
        self.failUnlessEqual(d, {"a": 11, "b": 2, "c": 25, "d": 30})

        d.subtract_num("a", 10)
        d.subtract_num("e", 10)
        d.subtract_num("f", 10, 15)
        self.failUnlessEqual(d, {"a": 1, "b": 2, "c": 25, "d": 30,
                                 "e": -10, "f": 5})

        self.failUnlessEqual(d.sum(), sum([1, 2, 25, 30, -10, 5]))

        d = dictutil.NumDict()
        d.inc("a")
        d.inc("a")
        d.inc("b", 5)
        self.failUnlessEqual(d, {"a": 2, "b": 6})
        d.dec("a")
        d.dec("c")
        d.dec("d", 5)
        self.failUnlessEqual(d, {"a": 1, "b": 6, "c": -1, "d": 4})
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [("a", 1), ("b", 6), ("c", -1), ("d", 4)])
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [("c", -1), ("a", 1), ("d", 4), ("b", 6)])
        self.failUnlessEqual(d.item_with_largest_value(), ("b", 6))

        d = dictutil.NumDict({"a": 1, "b": 2})
        self.failUnlessEqual(repr(d), "{'a': 1, 'b': 2}")
        self.failUnless("a" in d)

        d2 = dictutil.NumDict({"c": 3, "d": 4})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d["a"], 1)
        self.failUnlessEqual(sorted(list([k for k in d])), ["a","b"])
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.NumDict))

        d4 = d.fromkeys(["a","b"], 5)
        self.failUnlessEqual(d4, {"a": 5, "b": 5})

        self.failUnlessEqual(d.get("a"), 1)
        self.failUnlessEqual(d.get("c"), 0)
        self.failUnlessEqual(d.get("c", 5), 5)
        self.failUnlessEqual(sorted(list(d.items())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.keys()), ["a", "b"])
        self.failUnlessEqual(sorted(d.values()), [1, 2])
        self.failUnless(d.has_key("a"))
        self.failIf(d.has_key("c"))

        x = d.setdefault("c", 3)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        x = d.setdefault("c", 5)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        del d["c"]

        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        self.failUnlessRaises(KeyError, d.popitem)

        d.update({"c": 3})
        d.update({"c": 4, "d": 5})
        self.failUnlessEqual(d, {"c": 4, "d": 5})

    def test_del_if_present(self):
        d = {1: "a", 2: "b"}
        dictutil.del_if_present(d, 1)
        dictutil.del_if_present(d, 3)
        self.failUnlessEqual(d, {2: "b"})

    def test_valueordereddict(self):
        d = dictutil.ValueOrderedDict()
        d["a"] = 3
        d["b"] = 2
        d["c"] = 1

        # items/values/keys come back ordered by value, smallest first
        self.failUnlessEqual(d, {"a": 3, "b": 2, "c": 1})
        self.failUnlessEqual(d.items(), [("c", 1), ("b", 2), ("a", 3)])
        self.failUnlessEqual(d.values(), [1, 2, 3])
        self.failUnlessEqual(d.keys(), ["c", "b", "a"])
        self.failUnlessEqual(repr(d), "<ValueOrderedDict {c: 1, b: 2, a: 3}>")
        self.failIf(d == {"a": 4})
        self.failUnless(d != {"a": 4})

        x = d.setdefault("d", 0)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)
        x = d.setdefault("d", -1)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)

        x = d.remove("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.remove, "e", "default", True)
        x = d.remove("d", 5)
        self.failUnlessEqual(x, 0)

        x = d.__getitem__("c")
        self.failUnlessEqual(x, 1)
        x = d.__getitem__("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.__getitem__, "e", "default", True)

        self.failUnlessEqual(d.popitem(), ("c", 1))
        self.failUnlessEqual(d.popitem(), ("b", 2))
        self.failUnlessEqual(d.popitem(), ("a", 3))
        self.failUnlessRaises(KeyError, d.popitem)

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop("d", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.pop, "d", "default", True)
        x = d.pop("b")
        self.failUnlessEqual(x, 2)
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop_from_list(1) # pop the second item, b/2
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

    def test_auxdict(self):
        d = dictutil.AuxValueDict()
        # we put the serialized form in the auxdata
        d.set_with_aux("key", ("filecap", "metadata"), "serialized")

        self.failUnlessEqual(d.keys(), ["key"])
        self.failUnlessEqual(d["key"], ("filecap", "metadata"))
        self.failUnlessEqual(d.get_aux("key"), "serialized")
        def _get_missing(key):
            return d[key]
        self.failUnlessRaises(KeyError, _get_missing, "nonkey")
        self.failUnlessEqual(d.get("nonkey"), None)
        self.failUnlessEqual(d.get("nonkey", "nonvalue"), "nonvalue")
        self.failUnlessEqual(d.get_aux("nonkey"), None)
        self.failUnlessEqual(d.get_aux("nonkey", "nonvalue"), "nonvalue")

        # a plain assignment replaces the value and drops the old auxdata
        d["key"] = ("filecap2", "metadata2")
        self.failUnlessEqual(d["key"], ("filecap2", "metadata2"))
        self.failUnlessEqual(d.get_aux("key"), None)

        d.set_with_aux("key2", "value2", "aux2")
        self.failUnlessEqual(sorted(d.keys()), ["key", "key2"])
        del d["key2"]
        self.failUnlessEqual(d.keys(), ["key"])
        self.failIf("key2" in d)
        self.failUnlessRaises(KeyError, _get_missing, "key2")
        self.failUnlessEqual(d.get("key2"), None)
        self.failUnlessEqual(d.get_aux("key2"), None)
        d["key2"] = "newvalue2"
        self.failUnlessEqual(d.get("key2"), "newvalue2")
        self.failUnlessEqual(d.get_aux("key2"), None)

        d = dictutil.AuxValueDict({1:2,3:4})
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict([ (1,2), (3,4) ])
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict(one=1, two=2)
        self.failUnlessEqual(sorted(d.keys()), ["one","two"])
        self.failUnlessEqual(d["one"], 1)
        self.failUnlessEqual(d.get_aux("one"), None)
1405
1406 class Pipeline(unittest.TestCase):
1407     def pause(self, *args, **kwargs):
1408         d = defer.Deferred()
1409         self.calls.append( (d, args, kwargs) )
1410         return d
1411
1412     def failUnlessCallsAre(self, expected):
1413         #print self.calls
1414         #print expected
1415         self.failUnlessEqual(len(self.calls), len(expected), self.calls)
1416         for i,c in enumerate(self.calls):
1417             self.failUnlessEqual(c[1:], expected[i], str(i))
1418
    def test_basic(self):
        """Step a pipeline.Pipeline(100) through add, overflow, flush and drain.

        Each add() routes its call through self.pause, which records it in
        self.calls. The test then fires those recorded Deferreds by hand and
        checks p.gauge, the Deferreds returned by add(), and the flush()
        Deferred after every step.
        """
        self.calls = []
        finished = []
        p = pipeline.Pipeline(100)

        d = p.flush() # fires immediately
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        finished = []

        d = p.add(10, self.pause, "one")
        # the call should start right away, and our return Deferred should
        # fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ) ])
        # gauge is the sum of the sizes passed to add() for calls still pending
        self.failUnlessEqual(p.gauge, 10)

        # pipeline: [one]

        finished = []
        d = p.add(20, self.pause, "two", kw=2)
        # pipeline: [one, two]

        # the call and the Deferred should fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ])
        self.failUnlessEqual(p.gauge, 30)

        # retire call #1 by firing the Deferred that pause() handed out
        self.calls[0][0].callback("one-result")
        # pipeline: [two]
        self.failUnlessEqual(p.gauge, 20)

        finished = []
        d = p.add(90, self.pause, "three", "posarg1")
        # pipeline: [two, three]
        flushed = []
        fd = p.flush()
        fd.addCallbacks(flushed.append, log.err)
        # flush() must wait while calls are still outstanding
        self.failUnlessEqual(flushed, [])

        # the call will be made right away, but the return Deferred will not,
        # because the pipeline is now full.
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 0)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ( ("three", "posarg1"), {} ),
                                  ])
        self.failUnlessEqual(p.gauge, 110)

        # presumably refused because the previous add()'s Deferred has not
        # fired yet -- TODO confirm against pipeline.Pipeline.add
        self.failUnlessRaises(pipeline.SingleFileError, p.add, 10, self.pause)

        # retiring either call will unblock the pipeline, causing the #3
        # Deferred to fire
        self.calls[2][0].callback("three-result")
        # pipeline: [two]

        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessEqual(flushed, [])

        # retiring call#2 will finally allow the flush() Deferred to fire
        self.calls[1][0].callback("two-result")
        self.failUnlessEqual(len(flushed), 1)
1489
1490     def test_errors(self):
1491         self.calls = []
1492         p = pipeline.Pipeline(100)
1493
1494         d1 = p.add(200, self.pause, "one")
1495         d2 = p.flush()
1496
1497         finished = []
1498         d1.addBoth(finished.append)
1499         self.failUnlessEqual(finished, [])
1500
1501         flushed = []
1502         d2.addBoth(flushed.append)
1503         self.failUnlessEqual(flushed, [])
1504
1505         self.calls[0][0].errback(ValueError("oops"))
1506
1507         self.failUnlessEqual(len(finished), 1)
1508         f = finished[0]
1509         self.failUnless(isinstance(f, Failure))
1510         self.failUnless(f.check(pipeline.PipelineError))
1511         self.failUnlessIn("PipelineError", str(f.value))
1512         self.failUnlessIn("ValueError", str(f.value))
1513         r = repr(f.value)
1514         self.failUnless("ValueError" in r, r)
1515         f2 = f.value.error
1516         self.failUnless(f2.check(ValueError))
1517
1518         self.failUnlessEqual(len(flushed), 1)
1519         f = flushed[0]
1520         self.failUnless(isinstance(f, Failure))
1521         self.failUnless(f.check(pipeline.PipelineError))
1522         f2 = f.value.error
1523         self.failUnless(f2.check(ValueError))
1524
1525         # now that the pipeline is in the failed state, any new calls will
1526         # fail immediately
1527
1528         d3 = p.add(20, self.pause, "two")
1529
1530         finished = []
1531         d3.addBoth(finished.append)
1532         self.failUnlessEqual(len(finished), 1)
1533         f = finished[0]
1534         self.failUnless(isinstance(f, Failure))
1535         self.failUnless(f.check(pipeline.PipelineError))
1536         r = repr(f.value)
1537         self.failUnless("ValueError" in r, r)
1538         f2 = f.value.error
1539         self.failUnless(f2.check(ValueError))
1540
1541         d4 = p.flush()
1542         flushed = []
1543         d4.addBoth(flushed.append)
1544         self.failUnlessEqual(len(flushed), 1)
1545         f = flushed[0]
1546         self.failUnless(isinstance(f, Failure))
1547         self.failUnless(f.check(pipeline.PipelineError))
1548         f2 = f.value.error
1549         self.failUnless(f2.check(ValueError))
1550
    def test_errors2(self):
        # Three calls are in flight when the first one fails. flush() must
        # errback exactly once with a PipelineError wrapping the ValueError.
        # The later success of call #2 and failure of call #3 are then fired
        # with no further assertions: the test only requires that doing so
        # completes without raising.
        self.calls = []
        p = pipeline.Pipeline(100)

        d1 = p.add(10, self.pause, "one")
        d2 = p.add(20, self.pause, "two")
        d3 = p.add(30, self.pause, "three")
        d4 = p.flush()

        # one call fails, then the second one succeeds: make sure
        # ExpandableDeferredList tolerates the second one

        flushed = []
        d4.addBoth(flushed.append)
        self.failUnlessEqual(flushed, [])

        # failing call #1 must make the flush() Deferred errback right away
        self.calls[0][0].errback(ValueError("oops"))
        self.failUnlessEqual(len(flushed), 1)
        f = flushed[0]
        self.failUnless(isinstance(f, Failure))
        self.failUnless(f.check(pipeline.PipelineError))
        # the PipelineError carries the original Failure in .error
        f2 = f.value.error
        self.failUnless(f2.check(ValueError))

        # fire the remaining calls after the flush has already failed
        self.calls[1][0].callback("two-result")
        self.calls[2][0].errback(ValueError("three-error"))

        # d1-d3 are otherwise unused; presumably deleted explicitly to mark
        # them as intentionally unused
        del d1,d2,d3,d4
1579
class SampleError(Exception):
    """Exception raised on purpose by Log.test_err, so that the resulting
    logged failure can be recognized and flushed by its type."""
1582
class Log(unittest.TestCase):
    # Tests for allmydata.util.log (imported here as tahoe_log).
    def test_err(self):
        # tahoe_log.err() must accept a Failure. The twisted log.err it
        # records is flushed afterwards so this test does not [ERROR].
        if not hasattr(self, "flushLoggedErrors"):
            # without flushLoggedErrors, we can't get rid of the
            # twisted.log.err that tahoe_log records, so we can't keep this
            # test from [ERROR]ing
            raise unittest.SkipTest("needs flushLoggedErrors from Twisted-2.5.0")
        try:
            raise SampleError("simple sample")
        except SampleError:
            # only the exception raised just above is expected here; the
            # Failure is built inside the except clause so it wraps it
            f = Failure()
        tahoe_log.err(format="intentional sample error",
                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
        self.flushLoggedErrors(SampleError)
1597
1598
class SimpleSpans:
    # this is a simple+inefficient form of util.spans.Spans . We compare the
    # behavior of this reference model against the real (efficient) form.
    #
    # The whole state is self._have, the set of every integer offset that
    # is present, so each operation costs time proportional to the number
    # of offsets it touches.

    def __init__(self, _span_or_start=None, length=None):
        # SimpleSpans()              -> empty
        # SimpleSpans(start, length) -> offsets start .. start+length-1
        # SimpleSpans(spans)         -> union of the (start, length) pairs
        #                               produced by iterating `spans` (a list,
        #                               or another SimpleSpans/Spans); a
        #                               false value for `spans` gives an
        #                               empty set
        self._have = set()
        if length is not None:
            self._have.update(range(_span_or_start, _span_or_start+length))
        elif _span_or_start:
            for (start, length) in _span_or_start:
                self.add(start, length)

    def add(self, start, length):
        # mark offsets start .. start+length-1 present; returns self so
        # calls can be chained
        self._have.update(range(start, start+length))
        return self

    def remove(self, start, length):
        # unmark offsets start .. start+length-1 (absent offsets are
        # ignored); returns self so calls can be chained
        self._have.difference_update(range(start, start+length))
        return self

    def each(self):
        # sorted list of every present offset
        return sorted(self._have)

    def __iter__(self):
        # yield the maximal runs of consecutive present offsets as
        # (start, length) pairs, in increasing order
        items = sorted(self._have)
        prevstart = None
        prevend = None
        for i in items:
            if prevstart is None:
                prevstart = prevend = i
                continue
            if i == prevend+1:
                prevend = i
                continue
            yield (prevstart, prevend-prevstart+1)
            prevstart = prevend = i
        if prevstart is not None:
            yield (prevstart, prevend-prevstart+1)

    def __nonzero__(self): # this gets us bool() on Python 2
        # must return a real bool, not the raw offset count
        return bool(self._have)
    __bool__ = __nonzero__ # ... and on Python 3

    def len(self):
        # number of present offsets
        return len(self._have)

    def __add__(self, other):
        # union, as a new object; `other` is any iterable of
        # (start, length) pairs
        s = self.__class__(self)
        for (start, length) in other:
            s.add(start, length)
        return s

    def __sub__(self, other):
        # difference, as a new object
        s = self.__class__(self)
        for (start, length) in other:
            s.remove(start, length)
        return s

    def __iadd__(self, other):
        for (start, length) in other:
            self.add(start, length)
        return self

    def __isub__(self, other):
        for (start, length) in other:
            self.remove(start, length)
        return self

    def __and__(self, other):
        # intersection, as a new object; `other` must provide each()
        s = self.__class__()
        for i in other.each():
            if i in self._have:
                s.add(i, 1)
        return s

    def __contains__(self, span):
        # (start, length) in s: True iff every offset
        # start .. start+length-1 is present. This is NOT equivalent to
        # `span in list(s)`, and a zero-length span is vacuously contained.
        (start, length) = span
        return all(i in self._have for i in range(start, start+length))
1681
1682 class ByteSpans(unittest.TestCase):
1683     def test_basic(self):
1684         s = Spans()
1685         self.failUnlessEqual(list(s), [])
1686         self.failIf(s)
1687         self.failIf((0,1) in s)
1688         self.failUnlessEqual(s.len(), 0)
1689
1690         s1 = Spans(3, 4) # 3,4,5,6
1691         self._check1(s1)
1692
1693         s1 = Spans(3L, 4L) # 3,4,5,6
1694         self._check1(s1)
1695
1696         s2 = Spans(s1)
1697         self._check1(s2)
1698
1699         s2.add(10,2) # 10,11
1700         self._check1(s1)
1701         self.failUnless((10,1) in s2)
1702         self.failIf((10,1) in s1)
1703         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
1704         self.failUnlessEqual(s2.len(), 6)
1705
1706         s2.add(15,2).add(20,2)
1707         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
1708         self.failUnlessEqual(s2.len(), 10)
1709
1710         s2.remove(4,3).remove(15,1)
1711         self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
1712         self.failUnlessEqual(s2.len(), 6)
1713
1714         s1 = SimpleSpans(3, 4) # 3 4 5 6
1715         s2 = SimpleSpans(5, 4) # 5 6 7 8
1716         i = s1 & s2
1717         self.failUnlessEqual(list(i.each()), [5, 6])
1718
1719     def _check1(self, s):
1720         self.failUnlessEqual(list(s), [(3,4)])
1721         self.failUnless(s)
1722         self.failUnlessEqual(s.len(), 4)
1723         self.failIf((0,1) in s)
1724         self.failUnless((3,4) in s)
1725         self.failUnless((3,1) in s)
1726         self.failUnless((5,2) in s)
1727         self.failUnless((6,1) in s)
1728         self.failIf((6,2) in s)
1729         self.failIf((7,1) in s)
1730         self.failUnlessEqual(list(s.each()), [3,4,5,6])
1731
1732     def test_large(self):
1733         s = Spans(4, 2**65) # don't do this with a SimpleSpans
1734         self.failUnlessEqual(list(s), [(4, 2**65)])
1735         self.failUnless(s)
1736         self.failUnlessEqual(s.len(), 2**65)
1737         self.failIf((0,1) in s)
1738         self.failUnless((4,2) in s)
1739         self.failUnless((2**65,2) in s)
1740
1741     def test_math(self):
1742         s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
1743         s2 = Spans(5, 3) # 5,6,7
1744         s3 = Spans(8, 4) # 8,9,10,11
1745
1746         s = s1 - s2
1747         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1748         s = s1 - s3
1749         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1750         s = s2 - s3
1751         self.failUnlessEqual(list(s.each()), [5,6,7])
1752         s = s1 & s2
1753         self.failUnlessEqual(list(s.each()), [5,6,7])
1754         s = s2 & s1
1755         self.failUnlessEqual(list(s.each()), [5,6,7])
1756         s = s1 & s3
1757         self.failUnlessEqual(list(s.each()), [8,9])
1758         s = s3 & s1
1759         self.failUnlessEqual(list(s.each()), [8,9])
1760         s = s2 & s3
1761         self.failUnlessEqual(list(s.each()), [])
1762         s = s3 & s2
1763         self.failUnlessEqual(list(s.each()), [])
1764         s = Spans() & s3
1765         self.failUnlessEqual(list(s.each()), [])
1766         s = s3 & Spans()
1767         self.failUnlessEqual(list(s.each()), [])
1768
1769         s = s1 + s2
1770         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1771         s = s1 + s3
1772         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1773         s = s2 + s3
1774         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1775
1776         s = Spans(s1)
1777         s -= s2
1778         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1779         s = Spans(s1)
1780         s -= s3
1781         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1782         s = Spans(s2)
1783         s -= s3
1784         self.failUnlessEqual(list(s.each()), [5,6,7])
1785
1786         s = Spans(s1)
1787         s += s2
1788         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1789         s = Spans(s1)
1790         s += s3
1791         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1792         s = Spans(s2)
1793         s += s3
1794         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1795
1796     def test_random(self):
1797         # attempt to increase coverage of corner cases by comparing behavior
1798         # of a simple-but-slow model implementation against the
1799         # complex-but-fast actual implementation, in a large number of random
1800         # operations
1801         S1 = SimpleSpans
1802         S2 = Spans
1803         s1 = S1(); s2 = S2()
1804         seed = ""
1805         def _create(subseed):
1806             ns1 = S1(); ns2 = S2()
1807             for i in range(10):
1808                 what = _hash(subseed+str(i)).hexdigest()
1809                 start = int(what[2:4], 16)
1810                 length = max(1,int(what[5:6], 16))
1811                 ns1.add(start, length); ns2.add(start, length)
1812             return ns1, ns2
1813
1814         #print
1815         for i in range(1000):
1816             what = _hash(seed+str(i)).hexdigest()
1817             op = what[0]
1818             subop = what[1]
1819             start = int(what[2:4], 16)
1820             length = max(1,int(what[5:6], 16))
1821             #print what
1822             if op in "0":
1823                 if subop in "01234":
1824                     s1 = S1(); s2 = S2()
1825                 elif subop in "5678":
1826                     s1 = S1(start, length); s2 = S2(start, length)
1827                 else:
1828                     s1 = S1(s1); s2 = S2(s2)
1829                 #print "s2 = %s" % s2.dump()
1830             elif op in "123":
1831                 #print "s2.add(%d,%d)" % (start, length)
1832                 s1.add(start, length); s2.add(start, length)
1833             elif op in "456":
1834                 #print "s2.remove(%d,%d)" % (start, length)
1835                 s1.remove(start, length); s2.remove(start, length)
1836             elif op in "78":
1837                 ns1, ns2 = _create(what[7:11])
1838                 #print "s2 + %s" % ns2.dump()
1839                 s1 = s1 + ns1; s2 = s2 + ns2
1840             elif op in "9a":
1841                 ns1, ns2 = _create(what[7:11])
1842                 #print "%s - %s" % (s2.dump(), ns2.dump())
1843                 s1 = s1 - ns1; s2 = s2 - ns2
1844             elif op in "bc":
1845                 ns1, ns2 = _create(what[7:11])
1846                 #print "s2 += %s" % ns2.dump()
1847                 s1 += ns1; s2 += ns2
1848             elif op in "de":
1849                 ns1, ns2 = _create(what[7:11])
1850                 #print "%s -= %s" % (s2.dump(), ns2.dump())
1851                 s1 -= ns1; s2 -= ns2
1852             else:
1853                 ns1, ns2 = _create(what[7:11])
1854                 #print "%s &= %s" % (s2.dump(), ns2.dump())
1855                 s1 = s1 & ns1; s2 = s2 & ns2
1856             #print "s2 now %s" % s2.dump()
1857             self.failUnlessEqual(list(s1.each()), list(s2.each()))
1858             self.failUnlessEqual(s1.len(), s2.len())
1859             self.failUnlessEqual(bool(s1), bool(s2))
1860             self.failUnlessEqual(list(s1), list(s2))
1861             for j in range(10):
1862                 what = _hash(what[12:14]+str(j)).hexdigest()
1863                 start = int(what[2:4], 16)
1864                 length = max(1, int(what[5:6], 16))
1865                 span = (start, length)
1866                 self.failUnlessEqual(bool(span in s1), bool(span in s2))
1867
1868
1869     # s()
1870     # s(start,length)
1871     # s(s0)
1872     # s.add(start,length) : returns s
1873     # s.remove(start,length)
1874     # s.each() -> list of byte offsets, mostly for testing
1875     # list(s) -> list of (start,length) tuples, one per span
1876     # (start,length) in s -> True if (start..start+length-1) are all members
1877     #  NOT equivalent to x in list(s)
1878     # s.len() -> number of bytes, for testing, bool(), and accounting/limiting
1879     # bool(s)  (__nonzeron__)
1880     # s = s1+s2, s1-s2, +=s1, -=s1
1881
1882     def test_overlap(self):
1883         for a in range(20):
1884             for b in range(10):
1885                 for c in range(20):
1886                     for d in range(10):
1887                         self._test_overlap(a,b,c,d)
1888
1889     def _test_overlap(self, a, b, c, d):
1890         s1 = set(range(a,a+b))
1891         s2 = set(range(c,c+d))
1892         #print "---"
1893         #self._show_overlap(s1, "1")
1894         #self._show_overlap(s2, "2")
1895         o = overlap(a,b,c,d)
1896         expected = s1.intersection(s2)
1897         if not expected:
1898             self.failUnlessEqual(o, None)
1899         else:
1900             start,length = o
1901             so = set(range(start,start+length))
1902             #self._show(so, "o")
1903             self.failUnlessEqual(so, expected)
1904
1905     def _show_overlap(self, s, c):
1906         import sys
1907         out = sys.stdout
1908         if s:
1909             for i in range(max(s)):
1910                 if i in s:
1911                     out.write(c)
1912                 else:
1913                     out.write(" ")
1914         out.write("\n")
1915
def extend(s, start, length, fill):
    # Pad `s` with copies of the single character `fill` until it is at
    # least start+length long. An `s` that is already long enough is
    # returned untouched (and `fill` is not checked in that case).
    shortfall = start + length - len(s)
    if shortfall <= 0:
        return s
    assert len(fill) == 1
    return s + fill * shortfall
1921
def replace(s, start, data):
    # Return a copy of `s` in which the len(data) characters beginning at
    # offset `start` are overwritten by `data`; `s` must already be long
    # enough to contain that whole range.
    stop = start + len(data)
    assert len(s) >= stop
    head, tail = s[:start], s[stop:]
    return head + data + tail
1925
class SimpleDataSpans:
    # Simple+inefficient reference model of util.spans.DataSpans, compared
    # against the real implementation by the StringSpans tests. The state is
    # two strings indexed by offset:
    #   self.missing: "1" where missing, "0" where data is present
    #   self.data:    the stored character at each offset (" " padding where
    #                 nothing was ever added)
    def __init__(self, other=None):
        # optionally copy every stored byte out of `other` (anything with a
        # get_chunks() method, e.g. another SimpleDataSpans)
        self.missing = "" # "1" where missing, "0" where found
        self.data = ""
        if other:
            for (start, data) in other.get_chunks():
                self.add(start, data)

    def __nonzero__(self): # this gets us bool() on Python 2
        # must return a real bool, not the raw byte count
        return self.len() > 0
    __bool__ = __nonzero__ # ... and on Python 3

    def len(self):
        # number of offsets currently holding data
        return self.missing.count("0")

    def _dump(self):
        # sorted list of the offsets currently holding data
        return [i for (i,c) in enumerate(self.missing) if c == "0"]

    def _have(self, start, length):
        # True iff all `length` offsets starting at `start` hold data; a
        # zero-length request, or one running past the end, is not present
        m = self.missing[start:start+length]
        if not m or len(m) < length or "1" in m:
            return False
        return True

    def get_chunks(self):
        # yield (offset, one-character string) for every stored byte
        for i in self._dump():
            yield (i, self.data[i])

    def get_spans(self):
        # the occupied offsets as a SimpleSpans
        return SimpleSpans([(start,len(data))
                            for (start,data) in self.get_chunks()])

    def get(self, start, length):
        # the stored data for [start, start+length), or None unless every
        # offset in that range holds data
        if self._have(start, length):
            return self.data[start:start+length]
        return None

    def pop(self, start, length):
        # like get(), but also removes the range when it was fully present
        data = self.get(start, length)
        if data:
            self.remove(start, length)
        return data

    def remove(self, start, length):
        # mark [start, start+length) missing, growing self.missing if needed;
        # the old characters stay in self.data but are no longer reported
        self.missing = replace(extend(self.missing, start, length, "1"),
                               start, "1"*length)

    def add(self, start, data):
        # store `data` at `start`, overwriting whatever was there
        self.missing = replace(extend(self.missing, start, len(data), "1"),
                               start, "0"*len(data))
        self.data = replace(extend(self.data, start, len(data), " "),
                            start, data)
1968
1969
1970 class StringSpans(unittest.TestCase):
    def do_basic(self, klass):
        # Drive the DataSpans API of `klass` (the real DataSpans or the
        # SimpleDataSpans model) through a fixed sequence of add/get/pop/
        # remove calls. Each assertion depends on the state left by the
        # previous steps, so the statement order matters.

        # empty instance: nothing stored, get/pop return None, and remove()
        # on it must not raise
        ds = klass()
        self.failUnlessEqual(ds.len(), 0)
        self.failUnlessEqual(list(ds._dump()), [])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
        s = ds.get_spans()
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)
        ds.remove(0, 4)

        # one 4-byte chunk at offsets 2..5; reads that run outside it fail
        ds.add(2, "four")
        self.failUnlessEqual(ds.len(), 4)
        self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
        s = ds.get_spans()
        self.failUnless((2,2) in s)
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)
        self.failUnlessEqual(ds.get(4, 4), None)

        # a copy is independent: popping from ds2 leaves ds untouched
        ds2 = klass(ds)
        self.failUnlessEqual(ds2.len(), 4)
        self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
        self.failUnlessEqual(ds2.get(0, 4), None)
        self.failUnlessEqual(ds2.pop(0, 4), None)
        self.failUnlessEqual(ds2.pop(2, 3), "fou")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
        self.failUnlessEqual(ds2.get(2, 3), None)
        self.failUnlessEqual(ds2.get(5, 1), "r")
        self.failUnlessEqual(ds.get(2, 3), "fou")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)

        # filling offsets 0..1 makes 0..5 contiguous, so reads may span both
        ds.add(0, "23")
        self.failUnlessEqual(ds.len(), 6)
        self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
        self.failUnlessEqual(ds.get(0, 4), "23fo")
        self.failUnlessEqual(ds.pop(0, 4), "23fo")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)

        # an overlapping add overwrites the existing bytes ("four" -> "fear")
        ds = klass()
        ds.add(2, "four")
        ds.add(3, "ea")
        self.failUnlessEqual(ds.get(2, 4), "fear")

        # the same with long-integer offsets and lengths
        ds = klass()
        ds.add(2L, "four")
        ds.add(3L, "ea")
        self.failUnlessEqual(ds.get(2L, 4L), "fear")
2023
2024
2025     def do_scan(self, klass):
2026         # do a test with gaps and spans of size 1 and 2
2027         #  left=(1,11) * right=(1,11) * gapsize=(1,2)
2028         # 111, 112, 121, 122, 211, 212, 221, 222
2029         #    211
2030         #      121
2031         #         112
2032         #            212
2033         #               222
2034         #                   221
2035         #                      111
2036         #                        122
2037         #  11 1  1 11 11  11  1 1  111
2038         # 0123456789012345678901234567
2039         # abcdefghijklmnopqrstuvwxyz-=
2040         pieces = [(1, "bc"),
2041                   (4, "e"),
2042                   (7, "h"),
2043                   (9, "jk"),
2044                   (12, "mn"),
2045                   (16, "qr"),
2046                   (20, "u"),
2047                   (22, "w"),
2048                   (25, "z-="),
2049                   ]
2050         p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
2051         S = "abcdefghijklmnopqrstuvwxyz-="
2052         # TODO: when adding data, add capital letters, to make sure we aren't
2053         # just leaving the old data in place
2054         l = len(S)
2055         def base():
2056             ds = klass()
2057             for start, data in pieces:
2058                 ds.add(start, data)
2059             return ds
2060         def dump(s):
2061             p = set(s._dump())
2062             d = "".join([((i not in p) and " " or S[i]) for i in range(l)])
2063             assert len(d) == l
2064             return d
2065         DEBUG = False
2066         for start in range(0, l):
2067             for end in range(start+1, l):
2068                 # add [start-end) to the baseline
2069                 which = "%d-%d" % (start, end-1)
2070                 p_added = set(range(start, end))
2071                 b = base()
2072                 if DEBUG:
2073                     print
2074                     print dump(b), which
2075                     add = klass(); add.add(start, S[start:end])
2076                     print dump(add)
2077                 b.add(start, S[start:end])
2078                 if DEBUG:
2079                     print dump(b)
2080                 # check that the new span is there
2081                 d = b.get(start, end-start)
2082                 self.failUnlessEqual(d, S[start:end], which)
2083                 # check that all the original pieces are still there
2084                 for t_start, t_data in pieces:
2085                     t_len = len(t_data)
2086                     self.failUnlessEqual(b.get(t_start, t_len),
2087                                          S[t_start:t_start+t_len],
2088                                          "%s %d+%d" % (which, t_start, t_len))
2089                 # check that a lot of subspans are mostly correct
2090                 for t_start in range(l):
2091                     for t_len in range(1,4):
2092                         d = b.get(t_start, t_len)
2093                         if d is not None:
2094                             which2 = "%s+(%d-%d)" % (which, t_start,
2095                                                      t_start+t_len-1)
2096                             self.failUnlessEqual(d, S[t_start:t_start+t_len],
2097                                                  which2)
2098                         # check that removing a subspan gives the right value
2099                         b2 = klass(b)
2100                         b2.remove(t_start, t_len)
2101                         removed = set(range(t_start, t_start+t_len))
2102                         for i in range(l):
2103                             exp = (((i in p_elements) or (i in p_added))
2104                                    and (i not in removed))
2105                             which2 = "%s-(%d-%d)" % (which, t_start,
2106                                                      t_start+t_len-1)
2107                             self.failUnlessEqual(bool(b2.get(i, 1)), exp,
2108                                                  which2+" %d" % i)
2109
2110     def test_test(self):
2111         self.do_basic(SimpleDataSpans)
2112         self.do_scan(SimpleDataSpans)
2113
2114     def test_basic(self):
2115         self.do_basic(DataSpans)
2116         self.do_scan(DataSpans)
2117
    def test_random(self):
        # attempt to increase coverage of corner cases by comparing behavior
        # of a simple-but-slow model implementation against the
        # complex-but-fast actual implementation, in a large number of random
        # operations
        #
        # The "random" choices come from SHA-256 hex digests of seed+str(i),
        # so every run performs the same operation sequence.
        S1 = SimpleDataSpans
        S2 = DataSpans
        s1 = S1(); s2 = S2()
        seed = ""
        def _randstr(length, seed):
            # deterministic string of `length` hex characters: concatenated
            # digests of seed+str(n), truncated
            created = 0
            pieces = []
            while created < length:
                piece = _hash(seed + str(created)).hexdigest()
                pieces.append(piece)
                created += len(piece)
            return "".join(pieces)[:length]
        def _create(subseed):
            # a matching (model, real) pair built from 10 hash-derived adds
            ns1 = S1(); ns2 = S2()
            for i in range(10):
                what = _hash(subseed+str(i)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1,int(what[5:6], 16))
                ns1.add(start, _randstr(length, what[7:9]));
                ns2.add(start, _randstr(length, what[7:9]))
            return ns1, ns2

        #print
        for i in range(1000):
            # digest characters: [0] operation, [1] sub-operation,
            # [2:4] start offset (0-255), [5] length (clamped to 1-15)
            what = _hash(seed+str(i)).hexdigest()
            op = what[0]
            subop = what[1]
            start = int(what[2:4], 16)
            length = max(1,int(what[5:6], 16))
            #print what
            if op in "0":
                # reset: empty pair (subop 0-6) or a freshly built pair
                if subop in "0123456":
                    s1 = S1(); s2 = S2()
                else:
                    s1, s2 = _create(what[7:11])
                #print "s2 = %s" % list(s2._dump())
            elif op in "123456":
                #print "s2.add(%d,%d)" % (start, length)
                s1.add(start, _randstr(length, what[7:9]));
                s2.add(start, _randstr(length, what[7:9]))
            elif op in "789abc":
                #print "s2.remove(%d,%d)" % (start, length)
                s1.remove(start, length); s2.remove(start, length)
            else:
                # ops "d"-"f": pop from both; the popped data must agree
                #print "s2.pop(%d,%d)" % (start, length)
                d1 = s1.pop(start, length); d2 = s2.pop(start, length)
                self.failUnlessEqual(d1, d2)
            #print "s1 now %s" % list(s1._dump())
            #print "s2 now %s" % list(s2._dump())
            # after every operation both must agree on size, occupied
            # offsets, and 100 hash-derived get() probes
            self.failUnlessEqual(s1.len(), s2.len())
            self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
            for j in range(100):
                what = _hash(what[12:14]+str(j)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1, int(what[5:6], 16))
                d1 = s1.get(start, length); d2 = s2.get(start, length)
                self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))