]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
Rename 'constant_time_compare' to 'timing_safe_compare'. refs #2165
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): pass # must stay on line 2: HumanReadable.test_repr expects "<foo() at test_util.py:2>"
3
4 import os, time, sys
5 from StringIO import StringIO
6 from twisted.trial import unittest
7 from twisted.internet import defer, reactor
8 from twisted.python.failure import Failure
9 from twisted.python import log
10 from pycryptopp.hash.sha256 import SHA256 as _hash
11
12 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
13 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
14 from allmydata.util import limiter, time_format, pollmixin, cachedir
15 from allmydata.util import statistics, dictutil, pipeline
16 from allmydata.util import log as tahoe_log
17 from allmydata.util.spans import Spans, overlap, DataSpans
18
class Base32(unittest.TestCase):
    """Tests for allmydata.util.base32 encoding and decoding."""

    def test_b2a_matches_Pythons(self):
        # Our encoding must equal the stdlib's base32 output once the
        # trailing '=' padding is removed and the result is lowercased.
        import base64
        raw = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        expected = base64.b32encode(raw).rstrip('=').lower()
        self.failUnlessEqual(base32.b2a(raw), expected)

    def test_b2a(self):
        encoded = base32.b2a("\x12\x34")
        self.failUnlessEqual(encoded, "ci2a")

    def test_b2a_or_none(self):
        b2a_or_none = base32.b2a_or_none
        # None passes straight through; real data is encoded like b2a().
        self.failUnlessEqual(b2a_or_none(None), None)
        self.failUnlessEqual(b2a_or_none("\x12\x34"), "ci2a")

    def test_a2b(self):
        self.failUnlessEqual(base32.a2b("ci2a"), "\x12\x34")
        # Malformed input must be rejected with an AssertionError.
        self.failUnlessRaises(AssertionError, base32.a2b, "b0gus")
36
class IDLib(unittest.TestCase):
    """Tests for allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        # An all-zero 20-byte node id encodes to 32 'a' characters.
        zero_nodeid = "\x00" * 20
        expected = "a" * 32
        self.failUnlessEqual(idlib.nodeid_b2a(zero_nodeid), expected)
40
class NoArgumentException(Exception):
    """An exception whose constructor takes no arguments.

    HumanReadable.test_repr raises it to check how hr() renders an
    exception instance that carries no arguments.
    """
    def __init__(self):
        # Intentionally a no-op: Exception.__init__ is not called.
        return
44
45 class HumanReadable(unittest.TestCase):
46     def test_repr(self):
47         hr = humanreadable.hr
48         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
49         self.failUnlessEqual(hr(self.test_repr),
50                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
51         self.failUnlessEqual(hr(1L), "1")
52         self.failUnlessEqual(hr(10**40),
53                              "100000000000000000...000000000000000000")
54         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
55         self.failUnlessEqual(hr([1,2]), "[1, 2]")
56         self.failUnlessEqual(hr({1:2}), "{1:2}")
57         try:
58             raise ValueError
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<ValueError: ()>" # python-2.4
62                 or hr(e) == "ValueError()") # python-2.5
63         try:
64             raise ValueError("oops")
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<ValueError: 'oops'>" # python-2.4
68                 or hr(e) == "ValueError('oops',)") # python-2.5
69         try:
70             raise NoArgumentException
71         except Exception, e:
72             self.failUnless(
73                 hr(e) == "<NoArgumentException>" # python-2.4
74                 or hr(e) == "NoArgumentException()") # python-2.5
75
76
class MyList(list):
    """A trivial list subclass that adds no behaviour of its own."""
79
class Math(unittest.TestCase):
    """Tests for allmydata.util.mathutil."""

    def test_div_ceil(self):
        f = mathutil.div_ceil
        self.failUnlessEqual(f(0, 1), 0)
        self.failUnlessEqual(f(0, 2), 0)
        self.failUnlessEqual(f(0, 3), 0)
        self.failUnlessEqual(f(1, 3), 1)
        self.failUnlessEqual(f(2, 3), 1)
        self.failUnlessEqual(f(3, 3), 1)
        self.failUnlessEqual(f(4, 3), 2)
        self.failUnlessEqual(f(5, 3), 2)
        self.failUnlessEqual(f(6, 3), 2)
        self.failUnlessEqual(f(7, 3), 3)

    def test_next_multiple(self):
        f = mathutil.next_multiple
        self.failUnlessEqual(f(5, 1), 5)
        self.failUnlessEqual(f(5, 2), 6)
        self.failUnlessEqual(f(5, 3), 6)
        self.failUnlessEqual(f(5, 4), 8)
        self.failUnlessEqual(f(5, 5), 5)
        self.failUnlessEqual(f(5, 6), 6)
        self.failUnlessEqual(f(32, 1), 32)
        self.failUnlessEqual(f(32, 2), 32)
        self.failUnlessEqual(f(32, 3), 33)
        self.failUnlessEqual(f(32, 4), 32)
        self.failUnlessEqual(f(32, 5), 35)
        self.failUnlessEqual(f(32, 6), 36)
        self.failUnlessEqual(f(32, 7), 35)
        self.failUnlessEqual(f(32, 8), 32)
        self.failUnlessEqual(f(32, 9), 36)
        self.failUnlessEqual(f(32, 10), 40)
        self.failUnlessEqual(f(32, 11), 33)
        self.failUnlessEqual(f(32, 12), 36)
        self.failUnlessEqual(f(32, 13), 39)
        self.failUnlessEqual(f(32, 14), 42)
        self.failUnlessEqual(f(32, 15), 45)
        self.failUnlessEqual(f(32, 16), 32)
        self.failUnlessEqual(f(32, 17), 34)
        self.failUnlessEqual(f(32, 18), 36)
        self.failUnlessEqual(f(32, 589), 589)

    def test_pad_size(self):
        f = mathutil.pad_size
        self.failUnlessEqual(f(0, 4), 0)
        self.failUnlessEqual(f(1, 4), 3)
        self.failUnlessEqual(f(2, 4), 2)
        self.failUnlessEqual(f(3, 4), 1)
        self.failUnlessEqual(f(4, 4), 0)
        self.failUnlessEqual(f(5, 4), 3)

    def test_is_power_of_k(self):
        f = mathutil.is_power_of_k
        for i in range(1, 100):
            if i in (1, 2, 4, 8, 16, 32, 64):
                self.failUnless(f(i, 2), "but %d *is* a power of 2" % i)
            else:
                self.failIf(f(i, 2), "but %d is *not* a power of 2" % i)
        for i in range(1, 100):
            if i in (1, 3, 9, 27, 81):
                self.failUnless(f(i, 3), "but %d *is* a power of 3" % i)
            else:
                self.failIf(f(i, 3), "but %d is *not* a power of 3" % i)

    def test_next_power_of_k(self):
        f = mathutil.next_power_of_k
        # Each range below ends at (and includes) the exact power itself:
        # that boundary is where an off-by-one in next_power_of_k would
        # show up, and the previous ranges stopped one value short of it.
        self.failUnlessEqual(f(0,2), 1)
        self.failUnlessEqual(f(1,2), 1)
        self.failUnlessEqual(f(2,2), 2)
        self.failUnlessEqual(f(3,2), 4)
        self.failUnlessEqual(f(4,2), 4)
        for i in range(5, 9): self.failUnlessEqual(f(i,2), 8, "%d" % i)
        for i in range(9, 17): self.failUnlessEqual(f(i,2), 16, "%d" % i)
        for i in range(17, 33): self.failUnlessEqual(f(i,2), 32, "%d" % i)
        for i in range(33, 65): self.failUnlessEqual(f(i,2), 64, "%d" % i)
        for i in range(65, 129): self.failUnlessEqual(f(i,2), 128, "%d" % i)

        self.failUnlessEqual(f(0,3), 1)
        self.failUnlessEqual(f(1,3), 1)
        self.failUnlessEqual(f(2,3), 3)
        self.failUnlessEqual(f(3,3), 3)
        for i in range(4, 10): self.failUnlessEqual(f(i,3), 9, "%d" % i)
        for i in range(10, 28): self.failUnlessEqual(f(i,3), 27, "%d" % i)
        for i in range(28, 82): self.failUnlessEqual(f(i,3), 81, "%d" % i)
        for i in range(82, 244): self.failUnlessEqual(f(i,3), 243, "%d" % i)

    def test_ave(self):
        f = mathutil.ave
        self.failUnlessEqual(f([1,2,3]), 2)
        self.failUnlessEqual(f([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(f([0.0, 1.0, 1.0]), .666666666666)

    def test_round_sigfigs(self):
        f = mathutil.round_sigfigs
        self.failUnlessEqual(f(22.0/3, 4), 7.3330000000000002)
175
class Statistics(unittest.TestCase):
    """Tests for allmydata.util.statistics."""

    def should_assert(self, msg, func, *args, **kwargs):
        """Fail the test with `msg` unless func(*args, **kwargs) raises
        AssertionError.

        self.fail() must be called outside the try block. It raises the test
        case's failureException, which for trial (and for stdlib unittest)
        is an AssertionError subclass. Inside the try, `except AssertionError`
        would catch and discard it, so this helper could never fail.
        """
        try:
            func(*args, **kwargs)
        except AssertionError:
            return
        self.fail(msg)

    def failUnlessListEqual(self, a, b, msg = None):
        """Assert that sequences a and b have equal length and equal items."""
        self.failUnlessEqual(len(a), len(b))
        for x, y in zip(a, b):
            self.failUnlessEqual(x, y, msg)

    def failUnlessListAlmostEqual(self, a, b, places = 7, msg = None):
        """Assert equal lengths and pairwise failUnlessAlmostEqual items."""
        self.failUnlessEqual(len(a), len(b))
        for x, y in zip(a, b):
            self.failUnlessAlmostEqual(x, y, places, msg)

    def test_binomial_coeff(self):
        f = statistics.binomial_coeff
        self.failUnlessEqual(f(20, 0), 1)
        self.failUnlessEqual(f(20, 1), 20)
        self.failUnlessEqual(f(20, 2), 190)
        self.failUnlessEqual(f(20, 8), f(20, 12))
        self.should_assert("Should assert if n < k", f, 2, 3)

    def test_binomial_distribution_pmf(self):
        f = statistics.binomial_distribution_pmf

        pmf_comp = f(2, .1)
        pmf_stat = [0.81, 0.18, 0.01]
        self.failUnlessListAlmostEqual(pmf_comp, pmf_stat)

        # Summing across a PMF should give the total probability 1
        self.failUnlessAlmostEqual(sum(pmf_comp), 1)
        self.should_assert("Should assert if not 0<=p<=1", f, 1, -1)
        self.should_assert("Should assert if n < 1", f, 0, .1)

        out = StringIO()
        statistics.print_pmf(pmf_comp, out=out)
        lines = out.getvalue().splitlines()
        self.failUnlessEqual(lines[0], "i=0: 0.81")
        self.failUnlessEqual(lines[1], "i=1: 0.18")
        self.failUnlessEqual(lines[2], "i=2: 0.01")

    def test_survival_pmf(self):
        f = statistics.survival_pmf
        # Cross-check binomial-distribution method against convolution
        # method.
        p_list = [.9999] * 100 + [.99] * 50 + [.8] * 20
        pmf1 = statistics.survival_pmf_via_conv(p_list)
        pmf2 = statistics.survival_pmf_via_bd(p_list)
        self.failUnlessListAlmostEqual(pmf1, pmf2)
        self.failUnless(statistics.valid_pmf(pmf1))
        self.should_assert("Should assert if p_i > 1", f, [1.1])
        self.should_assert("Should assert if p_i < 0", f, [-.1])

    def test_repair_count_pmf(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        repair_pmf = statistics.repair_count_pmf(survival_pmf, 3)
        # repair_pmf[0] == sum(survival_pmf[0,1,2,5])
        # repair_pmf[1] == survival_pmf[4]
        # repair_pmf[2] == survival_pmf[3]
        self.failUnlessListAlmostEqual(repair_pmf,
                                       [0.00001 + 0.00045 + 0.0081 + 0.59049,
                                        .32805,
                                        .0729,
                                        0, 0, 0])

    def test_repair_cost(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        bwcost = statistics.bandwidth_cost_function
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 558.90)
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=8.0)
        self.failUnlessAlmostEqual(cost, 1664.55)

        # I haven't manually checked the math beyond here -warner
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 65292.056074766246)
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0.05,
                                              ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 9133.6097158191551)

    def test_convolve(self):
        f = statistics.convolve
        v1 = [ 1, 2, 3 ]
        v2 = [ 4, 5, 6 ]
        v3 = [ 7, 8 ]
        v1v2result = [ 4, 13, 28, 27, 18 ]
        # Convolution is commutative
        r1 = f(v1, v2)
        r2 = f(v2, v1)
        self.failUnlessListEqual(r1, r2, "Convolution should be commutative")
        self.failUnlessListEqual(r1, v1v2result, "Didn't match known result")
        # Convolution is associative
        r1 = f(f(v1, v2), v3)
        r2 = f(v1, f(v2, v3))
        self.failUnlessListEqual(r1, r2, "Convolution should be associative")
        # Convolution is distributive
        r1 = f(v3, [ a + b for a, b in zip(v1, v2) ])
        tmp1 = f(v3, v1)
        tmp2 = f(v3, v2)
        r2 = [ a + b for a, b in zip(tmp1, tmp2) ]
        self.failUnlessListEqual(r1, r2, "Convolution should be distributive")
        # Convolution is scalar multiplication associative
        tmp1 = f(v1, v2)
        r1 = [ a * 4 for a in tmp1 ]
        tmp2 = [ a * 4 for a in v1 ]
        r2 = f(tmp2, v2)
        self.failUnlessListEqual(r1, r2, "Convolution should be scalar multiplication associative")

    def test_find_k(self):
        f = statistics.find_k
        g = statistics.pr_file_loss
        plist = [.9] * 10 + [.8] * 10 # N=20
        t = .0001
        k = f(plist, t)
        self.failUnlessEqual(k, 10)
        self.failUnless(g(plist, k) < t)

    def test_pr_file_loss(self):
        f = statistics.pr_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, 3), .0546875)

    def test_pr_backup_file_loss(self):
        f = statistics.pr_backup_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, .5, 3), .02734375)
312
313
314 class Asserts(unittest.TestCase):
315     def should_assert(self, func, *args, **kwargs):
316         try:
317             func(*args, **kwargs)
318         except AssertionError, e:
319             return str(e)
320         except Exception, e:
321             self.fail("assert failed with non-AssertionError: %s" % e)
322         self.fail("assert was not caught")
323
324     def should_not_assert(self, func, *args, **kwargs):
325         try:
326             func(*args, **kwargs)
327         except AssertionError, e:
328             self.fail("assertion fired when it should not have: %s" % e)
329         except Exception, e:
330             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
331         return # we're happy
332
333
334     def test_assert(self):
335         f = assertutil._assert
336         self.should_assert(f)
337         self.should_assert(f, False)
338         self.should_not_assert(f, True)
339
340         m = self.should_assert(f, False, "message")
341         self.failUnlessEqual(m, "'message' <type 'str'>", m)
342         m = self.should_assert(f, False, "message1", othermsg=12)
343         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
344         m = self.should_assert(f, False, othermsg="message2")
345         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
346
347     def test_precondition(self):
348         f = assertutil.precondition
349         self.should_assert(f)
350         self.should_assert(f, False)
351         self.should_not_assert(f, True)
352
353         m = self.should_assert(f, False, "message")
354         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
355         m = self.should_assert(f, False, "message1", othermsg=12)
356         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
357         m = self.should_assert(f, False, othermsg="message2")
358         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
359
360     def test_postcondition(self):
361         f = assertutil.postcondition
362         self.should_assert(f)
363         self.should_assert(f, False)
364         self.should_not_assert(f, True)
365
366         m = self.should_assert(f, False, "message")
367         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
368         m = self.should_assert(f, False, "message1", othermsg=12)
369         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
370         m = self.should_assert(f, False, othermsg="message2")
371         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
372
373 class FileUtil(unittest.TestCase):
374     def mkdir(self, basedir, path, mode=0777):
375         fn = os.path.join(basedir, path)
376         fileutil.make_dirs(fn, mode)
377
378     def touch(self, basedir, path, mode=None, data="touch\n"):
379         fn = os.path.join(basedir, path)
380         f = open(fn, "w")
381         f.write(data)
382         f.close()
383         if mode is not None:
384             os.chmod(fn, mode)
385
386     def test_rm_dir(self):
387         basedir = "util/FileUtil/test_rm_dir"
388         fileutil.make_dirs(basedir)
389         # create it again to test idempotency
390         fileutil.make_dirs(basedir)
391         d = os.path.join(basedir, "doomed")
392         self.mkdir(d, "a/b")
393         self.touch(d, "a/b/1.txt")
394         self.touch(d, "a/b/2.txt", 0444)
395         self.touch(d, "a/b/3.txt", 0)
396         self.mkdir(d, "a/c")
397         self.touch(d, "a/c/1.txt")
398         self.touch(d, "a/c/2.txt", 0444)
399         self.touch(d, "a/c/3.txt", 0)
400         os.chmod(os.path.join(d, "a/c"), 0444)
401         self.mkdir(d, "a/d")
402         self.touch(d, "a/d/1.txt")
403         self.touch(d, "a/d/2.txt", 0444)
404         self.touch(d, "a/d/3.txt", 0)
405         os.chmod(os.path.join(d, "a/d"), 0)
406
407         fileutil.rm_dir(d)
408         self.failIf(os.path.exists(d))
409         # remove it again to test idempotency
410         fileutil.rm_dir(d)
411
412     def test_remove_if_possible(self):
413         basedir = "util/FileUtil/test_remove_if_possible"
414         fileutil.make_dirs(basedir)
415         self.touch(basedir, "here")
416         fn = os.path.join(basedir, "here")
417         fileutil.remove_if_possible(fn)
418         self.failIf(os.path.exists(fn))
419         fileutil.remove_if_possible(fn) # should be idempotent
420         fileutil.rm_dir(basedir)
421         fileutil.remove_if_possible(fn) # should survive errors
422
423     def test_write_atomically(self):
424         basedir = "util/FileUtil/test_write_atomically"
425         fileutil.make_dirs(basedir)
426         fn = os.path.join(basedir, "here")
427         fileutil.write_atomically(fn, "one")
428         self.failUnlessEqual(fileutil.read(fn), "one")
429         fileutil.write_atomically(fn, "two", mode="") # non-binary
430         self.failUnlessEqual(fileutil.read(fn), "two")
431
432     def test_open_or_create(self):
433         basedir = "util/FileUtil/test_open_or_create"
434         fileutil.make_dirs(basedir)
435         fn = os.path.join(basedir, "here")
436         f = fileutil.open_or_create(fn)
437         f.write("stuff.")
438         f.close()
439         f = fileutil.open_or_create(fn)
440         f.seek(0, os.SEEK_END)
441         f.write("more.")
442         f.close()
443         f = open(fn, "r")
444         data = f.read()
445         f.close()
446         self.failUnlessEqual(data, "stuff.more.")
447
448     def test_NamedTemporaryDirectory(self):
449         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
450         fileutil.make_dirs(basedir)
451         td = fileutil.NamedTemporaryDirectory(dir=basedir)
452         name = td.name
453         self.failUnless(basedir in name)
454         self.failUnless(basedir in repr(td))
455         self.failUnless(os.path.isdir(name))
456         del td
457         # it is conceivable that we need to force gc here, but I'm not sure
458         self.failIf(os.path.isdir(name))
459
460     def test_rename(self):
461         basedir = "util/FileUtil/test_rename"
462         fileutil.make_dirs(basedir)
463         self.touch(basedir, "here")
464         fn = os.path.join(basedir, "here")
465         fn2 = os.path.join(basedir, "there")
466         fileutil.rename(fn, fn2)
467         self.failIf(os.path.exists(fn))
468         self.failUnless(os.path.exists(fn2))
469
470     def test_du(self):
471         basedir = "util/FileUtil/test_du"
472         fileutil.make_dirs(basedir)
473         d = os.path.join(basedir, "space-consuming")
474         self.mkdir(d, "a/b")
475         self.touch(d, "a/b/1.txt", data="a"*10)
476         self.touch(d, "a/b/2.txt", data="b"*11)
477         self.mkdir(d, "a/c")
478         self.touch(d, "a/c/1.txt", data="c"*12)
479         self.touch(d, "a/c/2.txt", data="d"*13)
480
481         used = fileutil.du(basedir)
482         self.failUnlessEqual(10+11+12+13, used)
483
484     def test_abspath_expanduser_unicode(self):
485         self.failUnlessRaises(AssertionError, fileutil.abspath_expanduser_unicode, "bytestring")
486
487         saved_cwd = os.path.normpath(os.getcwdu())
488         abspath_cwd = fileutil.abspath_expanduser_unicode(u".")
489         self.failUnless(isinstance(saved_cwd, unicode), saved_cwd)
490         self.failUnless(isinstance(abspath_cwd, unicode), abspath_cwd)
491         self.failUnlessEqual(abspath_cwd, saved_cwd)
492
493         # adapted from <http://svn.python.org/view/python/branches/release26-maint/Lib/test/test_posixpath.py?view=markup&pathrev=78279#test_abspath>
494
495         self.failUnlessIn(u"foo", fileutil.abspath_expanduser_unicode(u"foo"))
496         self.failIfIn(u"~", fileutil.abspath_expanduser_unicode(u"~"))
497
498         cwds = ['cwd']
499         try:
500             cwds.append(u'\xe7w\xf0'.encode(sys.getfilesystemencoding()
501                                             or 'ascii'))
502         except UnicodeEncodeError:
503             pass # the cwd can't be encoded -- test with ascii cwd only
504
505         for cwd in cwds:
506             try:
507                 os.mkdir(cwd)
508                 os.chdir(cwd)
509                 for upath in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\', u'~'):
510                     uabspath = fileutil.abspath_expanduser_unicode(upath)
511                     self.failUnless(isinstance(uabspath, unicode), uabspath)
512             finally:
513                 os.chdir(saved_cwd)
514
515     def test_disk_stats(self):
516         avail = fileutil.get_available_space('.', 2**14)
517         if avail == 0:
518             raise unittest.SkipTest("This test will spuriously fail there is no disk space left.")
519
520         disk = fileutil.get_disk_stats('.', 2**13)
521         self.failUnless(disk['total'] > 0, disk['total'])
522         self.failUnless(disk['used'] > 0, disk['used'])
523         self.failUnless(disk['free_for_root'] > 0, disk['free_for_root'])
524         self.failUnless(disk['free_for_nonroot'] > 0, disk['free_for_nonroot'])
525         self.failUnless(disk['avail'] > 0, disk['avail'])
526
527     def test_disk_stats_avail_nonnegative(self):
528         # This test will spuriously fail if you have more than 2^128
529         # bytes of available space on your filesystem.
530         disk = fileutil.get_disk_stats('.', 2**128)
531         self.failUnlessEqual(disk['avail'], 0)
532
class PollMixinTests(unittest.TestCase):
    """Tests for allmydata.util.pollmixin.PollMixin.poll()."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # check_f always returns True; the returned Deferred must fire cleanly.
        return self.pm.poll(check_f=lambda : True, pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # check_f reports False on the first call and True on the second.
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next, pollinterval=0.1)

    def test_timeout(self):
        def _should_not_succeed(res):
            self.fail("poll should have failed, not returned %s" % (res,))

        def _expect_timeout(failure):
            # Only a TimeoutError counts as the expected outcome.
            failure.trap(pollmixin.TimeoutError)
            return None # success

        d = self.pm.poll(check_f=lambda: False, pollinterval=0.01, timeout=1)
        d.addCallbacks(_should_not_succeed, _expect_timeout)
        return d
559
class DeferredUtilTests(unittest.TestCase):
    """Tests for allmydata.util.deferredutil."""

    def test_gather_results(self):
        first, second = defer.Deferred(), defer.Deferred()
        gathered = deferredutil.gatherResults([first, second])
        first.errback(ValueError("BAD"))

        def _unexpected_success(res):
            self.fail("Should have errbacked, not resulted in %s" % (res,))

        def _expected_failure(thef):
            thef.trap(ValueError)

        gathered.addCallbacks(_unexpected_success, _expected_failure)
        return gathered

    def test_success(self):
        first, second = defer.Deferred(), defer.Deferred()
        good, bad = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(good.append, bad.append)
        first.callback(1)
        second.callback(2)
        # Both succeeded: one callback carrying the list of results.
        self.failUnlessEqual(good, [[1,2]])
        self.failUnlessEqual(bad, [])

    def test_failure(self):
        first, second = defer.Deferred(), defer.Deferred()
        good, bad = [], []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(good.append, bad.append)
        # Swallow the individual errors so they are not reported as
        # unhandled; only the aggregate outcome is under test.
        first.addErrback(lambda _ignore: None)
        second.addErrback(lambda _ignore: None)
        first.callback(1)
        second.errback(ValueError())
        self.failUnlessEqual(good, [])
        self.failUnlessEqual(len(bad), 1)
        failure = bad[0]
        self.failUnless(isinstance(failure, Failure))
        self.failUnless(failure.check(ValueError))
599
class HashUtilTests(unittest.TestCase):
    """Tests for allmydata.util.hashutil."""

    def test_random_key(self):
        key = hashutil.random_key()
        self.failUnlessEqual(len(key), hashutil.KEYLEN)

    def test_sha256d(self):
        # The one-shot tagged hash must match the incremental hasher, and
        # calling digest() twice must return the same value.
        one_shot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        first_digest = hasher.digest()
        second_digest = hasher.digest()
        self.failUnlessEqual(one_shot, first_digest)
        self.failUnlessEqual(first_digest, second_digest)

    def test_sha256d_truncated(self):
        one_shot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(one_shot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(one_shot, incremental)

    def test_chk(self):
        one_shot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        self.failUnlessEqual(one_shot, hasher.digest())

    def test_hashers(self):
        # Every one-shot *_hash function must agree with feeding the same
        # input to the matching incremental *_hasher.
        pairs = [
            (hashutil.block_hash, hashutil.block_hasher),
            (hashutil.uri_extension_hash, hashutil.uri_extension_hasher),
            (hashutil.plaintext_hash, hashutil.plaintext_hasher),
            (hashutil.crypttext_hash, hashutil.crypttext_hasher),
            (hashutil.crypttext_segment_hash, hashutil.crypttext_segment_hasher),
            (hashutil.plaintext_segment_hash, hashutil.plaintext_segment_hasher),
            ]
        for one_shot, make_hasher in pairs:
            expected = one_shot("foo")
            hasher = make_hasher()
            hasher.update("foo")
            self.failUnlessEqual(expected, hasher.digest())

    def test_timing_safe_compare(self):
        compare = hashutil.timing_safe_compare
        self.failUnless(compare("a", "a"))
        self.failUnless(compare("ab", "ab"))
        self.failIf(compare("a", "b"))
        self.failIf(compare("a", "aa"))

    def _testknown(self, hashf, expected_a, *args):
        """Assert that base32(hashf(*args)) equals expected_a."""
        self.failUnlessEqual(base32.b2a(hashf(*args)), expected_a)

    def test_known_answers(self):
        # assert backwards compatibility
        check = self._testknown
        check(hashutil.storage_index_hash, "qb5igbhcc5esa6lwqorsy7e6am", "")
        check(hashutil.block_hash, "msjr5bh4evuh7fa3zw7uovixfbvlnstr5b65mrerwfnvjxig2jvq", "")
        check(hashutil.uri_extension_hash, "wthsu45q7zewac2mnivoaa4ulh5xvbzdmsbuyztq2a5fzxdrnkka", "")
        check(hashutil.plaintext_hash, "5lz5hwz3qj3af7n6e3arblw7xzutvnd3p3fjsngqjcb7utf3x3da", "")
        check(hashutil.crypttext_hash, "itdj6e4njtkoiavlrmxkvpreosscssklunhwtvxn6ggho4rkqwga", "")
        check(hashutil.crypttext_segment_hash, "aovy5aa7jej6ym5ikgwyoi4pxawnoj3wtaludjz7e2nb5xijb7aa", "")
        check(hashutil.plaintext_segment_hash, "4fdgf6qruaisyukhqcmoth4t3li6bkolbxvjy4awwcpprdtva7za", "")
        check(hashutil.convergence_hash, "3mo6ni7xweplycin6nowynw2we", 3, 10, 100, "", "converge")
        check(hashutil.my_renewal_secret_hash, "ujhr5k5f7ypkp67jkpx6jl4p47pyta7hu5m527cpcgvkafsefm6q", "")
        check(hashutil.my_cancel_secret_hash, "rjwzmafe2duixvqy6h47f5wfrokdziry6zhx4smew4cj6iocsfaa", "")
        check(hashutil.file_renewal_secret_hash, "hzshk2kf33gzbd5n3a6eszkf6q6o6kixmnag25pniusyaulqjnia", "", "si")
        check(hashutil.file_cancel_secret_hash, "bfciwvr6w7wcavsngxzxsxxaszj72dej54n4tu2idzp6b74g255q", "", "si")
        check(hashutil.bucket_renewal_secret_hash, "e7imrzgzaoashsncacvy3oysdd2m5yvtooo4gmj4mjlopsazmvuq", "", "\x00"*20)
        check(hashutil.bucket_cancel_secret_hash, "dvdujeyxeirj6uux6g7xcf4lvesk632aulwkzjar7srildvtqwma", "", "\x00"*20)
        check(hashutil.hmac, "c54ypfi6pevb3nvo6ba42jtglpkry2kbdopqsi7dgrm4r7tw5sra", "tag", "")
        check(hashutil.mutable_rwcap_key_hash, "6rvn2iqrghii5n4jbbwwqqsnqu", "iv", "wk")
        check(hashutil.ssk_writekey_hash, "ykpgmdbpgbb6yqz5oluw2q26ye", "")
        check(hashutil.ssk_write_enabler_master_hash, "izbfbfkoait4dummruol3gy2bnixrrrslgye6ycmkuyujnenzpia", "")
        check(hashutil.ssk_write_enabler_hash, "fuu2dvx7g6gqu5x22vfhtyed7p4pd47y5hgxbqzgrlyvxoev62tq", "wk", "\x00"*20)
        check(hashutil.ssk_pubkey_fingerprint_hash, "3opzw4hhm2sgncjx224qmt5ipqgagn7h5zivnfzqycvgqgmgz35q", "")
        check(hashutil.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", "")
        check(hashutil.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", "iv", "rk")
        check(hashutil.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", "")
698
699
class Abbreviate(unittest.TestCase):
    """Tests for the human-readable time/size helpers in abbreviate."""

    def test_time(self):
        MIN = 60
        HOUR = 60*MIN
        DAY = 24*HOUR
        MONTH = 30*DAY
        YEAR = 365*DAY
        # (seconds, expected rendering), checked in this order
        cases = [(None, "unknown"),
                 (0, "0 seconds"),
                 (1, "1 second"),
                 (2, "2 seconds"),
                 (119, "119 seconds"),
                 (2*MIN, "2 minutes"),
                 (60*MIN, "60 minutes"),
                 (179*MIN, "179 minutes"),
                 (180*MIN, "3 hours"),
                 (4*HOUR, "4 hours"),
                 (2*DAY, "2 days"),
                 (2*MONTH, "2 months"),
                 (5*YEAR, "5 years"),
                 ]
        for (seconds, expected) in cases:
            self.failUnlessEqual(abbreviate.abbreviate_time(seconds), expected)

    def test_space(self):
        def _check(cases, si):
            # each case is (size in bytes, expected rendering)
            for (size, expected) in cases:
                rendered = abbreviate.abbreviate_space(size, SI=si)
                self.failUnlessEqual(rendered, expected)

        # powers of 1000, "kB"/"MB"/... suffixes
        decimal_cases = [(None, "unknown"),
                         (0, "0 B"),
                         (1, "1 B"),
                         (999, "999 B"),
                         (1000, "1000 B"),
                         (1023, "1023 B"),
                         (1024, "1.02 kB"),
                         (20*1000, "20.00 kB"),
                         (1024*1024, "1.05 MB"),
                         (1000*1000, "1.00 MB"),
                         (1000*1000*1000, "1.00 GB"),
                         (1000*1000*1000*1000, "1.00 TB"),
                         (1000*1000*1000*1000*1000, "1.00 PB"),
                         (1000*1000*1000*1000*1000*1000, "1.00 EB"),
                         (1234567890123456789, "1.23 EB"),
                         ]
        _check(decimal_cases, True)

        # powers of 1024, "kiB"/"MiB"/... suffixes
        binary_cases = [(None, "unknown"),
                        (0, "0 B"),
                        (1, "1 B"),
                        (999, "999 B"),
                        (1000, "1000 B"),
                        (1023, "1023 B"),
                        (1024, "1.00 kiB"),
                        (20*1024, "20.00 kiB"),
                        (1000*1000, "976.56 kiB"),
                        (1024*1024, "1.00 MiB"),
                        (1024*1024*1024, "1.00 GiB"),
                        (1024*1024*1024*1024, "1.00 TiB"),
                        (1000*1000*1000*1000*1000, "909.49 TiB"),
                        (1024*1024*1024*1024*1024, "1.00 PiB"),
                        (1024*1024*1024*1024*1024*1024, "1.00 EiB"),
                        (1234567890123456789, "1.07 EiB"),
                        ]
        _check(binary_cases, False)

        self.failUnlessEqual(abbreviate.abbreviate_space_both(1234567),
                             "(1.23 MB, 1.18 MiB)")

    def test_parse_space(self):
        parse = abbreviate.parse_abbreviated_size
        # (input, expected byte count); decimal suffixes use 1000,
        # "i" suffixes use 1024
        good = [("", None),
                (None, None),
                ("123", 123),
                ("123B", 123),
                ("2K", 2000),
                ("2kb", 2000),
                ("2KiB", 2048),
                ("10MB", 10*1000**2),
                ("10MiB", 10*1024**2),
                ("5G", 5*1000**3),
                ("4GiB", 4*1024**3),
                ("3TB", 3*1000**4),
                ("3TiB", 3*1024**4),
                ("6PB", 6*1000**5),
                ("6PiB", 6*1024**5),
                ("9EB", 9*1000**6),
                ("9EiB", 9*1024**6),
                ]
        for (text, expected) in good:
            self.failUnlessEqual(parse(text), expected)

        # unparseable input raises ValueError that quotes the input
        for bad in ("12 cubits", "1 BB", "fhtagn"):
            e = self.failUnlessRaises(ValueError, parse, bad)
            self.failUnlessIn(bad, str(e))
793
class Limiter(unittest.TestCase):
    """Tests for limiter.ConcurrencyLimiter."""
    timeout = 480 # This takes longer than 240 seconds on Francois's arm box.

    def job(self, i, foo):
        """Record the call, track how many jobs are running at once, and
        fire the returned Deferred with "done <i>" one second later."""
        self.calls.append( (i, foo) )
        self.simultaneous += 1
        self.peak_simultaneous = max(self.simultaneous, self.peak_simultaneous)
        d = defer.Deferred()
        def _done():
            self.simultaneous -= 1
            d.callback("done %d" % i)
        reactor.callLater(1.0, _done)
        return d

    def bad_job(self, i, foo):
        """A job that fails synchronously with ValueError."""
        raise ValueError("bad_job %d" % i)

    def _reset_counters(self):
        """Clear the bookkeeping that job() updates."""
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0

    def _add_jobs(self, lim, count):
        """Queue job(i, foo=str(i)) for i in range(count) on lim and return
        the list of resulting Deferreds."""
        return [lim.add(self.job, i, foo=str(i)) for i in range(count)]

    def _check_calls(self, count):
        """Assert that at most 10 jobs ever ran at once and that each of the
        `count` queued jobs ran exactly once with its own arguments."""
        self.failUnless(self.peak_simultaneous <= 10,
                        "peak concurrency was %d" % self.peak_simultaneous)
        self.failUnlessEqual(len(self.calls), count)
        for i in range(count):
            # failUnlessIn reports the offending tuple on failure
            self.failUnlessIn((i, str(i)), self.calls)

    def test_limiter(self):
        self._reset_counters()
        lim = limiter.ConcurrencyLimiter()
        d = defer.DeferredList(self._add_jobs(lim, 20), fireOnOneErrback=True)
        def _done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls(20)
        d.addCallback(_done)
        return d

    def test_errors(self):
        self._reset_counters()
        lim = limiter.ConcurrencyLimiter()
        dl = self._add_jobs(lim, 20)
        # a failing job added after the 20 good ones
        d2 = lim.add(self.bad_job, 21, "21")
        d = defer.DeferredList(dl, fireOnOneErrback=True)
        def _most_done(res):
            results = []
            for (success, result) in res:
                self.failUnlessEqual(success, True)
                results.append(result)
            results.sort()
            expected_results = sorted(["done %d" % i for i in range(20)])
            self.failUnlessEqual(results, expected_results)
            self._check_calls(20)
            def _good(res):
                self.fail("should have failed, not got %s" % (res,))
            def _err(f):
                f.trap(ValueError)
                self.failUnlessIn("bad_job 21", str(f))
            # chain the bad job's Deferred so the test waits for (and
            # consumes) its failure
            d2.addCallbacks(_good, _err)
            return d2
        d.addCallback(_most_done)
        def _all_done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_calls(20)
        d.addCallback(_all_done)
        return d
868
class TimeFormat(unittest.TestCase):
    """Tests for time_format: ISO-8601 UTC formatting/parsing, durations
    and dates."""

    def test_epoch(self):
        return self._help_test_epoch()

    def test_epoch_in_London(self):
        # Europe/London is a particularly troublesome timezone.  Nowadays, its
        # offset from GMT is 0.  But in 1970, its offset from GMT was 1.
        # (Apparently in 1970 Britain had redefined standard time to be GMT+1
        # and stayed in standard time all year round, whereas today
        # Europe/London standard time is GMT and Europe/London Daylight
        # Savings Time is GMT+1.)  time_format's conversions have
        # historically broken in this timezone, so re-run the epoch checks
        # with TZ=Europe/London to guard against regressions.
        origtz = os.environ.get('TZ')
        try:
            # Modify TZ inside the try so the finally clause always restores
            # the original timezone, even if tzset() or the checks fail.
            os.environ['TZ'] = "Europe/London"
            if hasattr(time, 'tzset'):
                time.tzset()
            return self._help_test_epoch()
        finally:
            if origtz is None:
                os.environ.pop('TZ', None)
            else:
                os.environ['TZ'] = origtz
            if hasattr(time, 'tzset'):
                time.tzset()

    def _help_test_epoch(self):
        """Check UTC ISO-8601 round-tripping around the epoch and at a
        DST-season timestamp, and that none of it changes time.tzname."""
        origtzname = time.tzname
        # all three date/time separators are accepted
        s = time_format.iso_utc_time_to_seconds("1970-01-01T00:00:01")
        self.failUnlessEqual(s, 1.0)
        s = time_format.iso_utc_time_to_seconds("1970-01-01_00:00:01")
        self.failUnlessEqual(s, 1.0)
        s = time_format.iso_utc_time_to_seconds("1970-01-01 00:00:01")
        self.failUnlessEqual(s, 1.0)

        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")

        now = time.time()
        isostr = time_format.iso_utc(now)
        timestamp = time_format.iso_utc_time_to_seconds(isostr)
        self.failUnlessEqual(int(timestamp), int(now))

        def my_time():
            return 1.0
        self.failUnlessEqual(time_format.iso_utc(t=my_time),
                             "1970-01-01_00:00:01")
        e = self.failUnlessRaises(ValueError,
                                  time_format.iso_utc_time_to_seconds,
                                  "invalid timestring")
        self.failUnlessIn("not a complete ISO8601 timestamp", str(e))
        s = time_format.iso_utc_time_to_seconds("1970-01-01_00:00:01.500")
        self.failUnlessEqual(s, 1.5)

        # Look for daylight-savings-related errors.
        thatmomentinmarch = time_format.iso_utc_time_to_seconds("2009-03-20 21:49:02.226536")
        self.failUnlessEqual(thatmomentinmarch, 1237585742.226536)
        self.failUnlessEqual(origtzname, time.tzname)

    def test_iso_utc(self):
        when = 1266760143.7841301
        out = time_format.iso_utc_date(when)
        self.failUnlessEqual(out, "2010-02-21")
        out = time_format.iso_utc_date(t=lambda: when)
        self.failUnlessEqual(out, "2010-02-21")
        out = time_format.iso_utc(when)
        self.failUnlessEqual(out, "2010-02-21_13:49:03.784130")
        out = time_format.iso_utc(when, sep="-")
        self.failUnlessEqual(out, "2010-02-21-13:49:03.784130")

    def test_parse_duration(self):
        p = time_format.parse_duration
        DAY = 24*60*60
        self.failUnlessEqual(p("1 day"), DAY)
        self.failUnlessEqual(p("2 days"), 2*DAY)
        # a "month" is counted as 31 days and a "year" as 365 days
        self.failUnlessEqual(p("3 months"), 3*31*DAY)
        self.failUnlessEqual(p("4 mo"), 4*31*DAY)
        self.failUnlessEqual(p("5 years"), 5*365*DAY)
        e = self.failUnlessRaises(ValueError, p, "123")
        self.failUnlessIn("no unit (like day, month, or year) in '123'",
                          str(e))

    def test_parse_date(self):
        self.failUnlessEqual(time_format.parse_date("2010-02-21"), 1266710400)
956
class CacheDir(unittest.TestCase):
    def test_basic(self):
        """Check when CacheDirectoryManager.check() removes cache files.

        As exercised below, a file is removed only after the test drops its
        CacheFile reference AND the file is older than cdm.old.
        """
        basedir = "test_util/CacheDir/test_basic"

        def _failIfExists(name):
            absfn = os.path.join(basedir, name)
            self.failIf(os.path.exists(absfn),
                        "%s exists but it shouldn't" % absfn)

        def _failUnlessExists(name):
            absfn = os.path.join(basedir, name)
            self.failUnless(os.path.exists(absfn),
                            "%s doesn't exist but it should" % absfn)

        def _expect(present, absent=()):
            # assert which of the cache files are (and are not) on disk
            for name in absent:
                _failIfExists(name)
            for name in present:
                _failUnlessExists(name)

        def _write_hi(cachefile):
            # Close the handle even if write() fails.  Doing the write in a
            # helper also leaves no lingering local bound to 'cachefile'; the
            # test relies on 'del a' / 'del b' dropping the last reference.
            f = open(cachefile.get_filename(), "wb")
            try:
                f.write("hi")
            finally:
                f.close()

        cdm = cachedir.CacheDirectoryManager(basedir)
        a = cdm.get_file("a")
        b = cdm.get_file("b")
        c = cdm.get_file("c")
        _write_hi(a)
        _write_hi(b)
        _write_hi(c)
        _expect(present=("a", "b", "c"))

        cdm.check()
        _expect(present=("a", "b", "c"))

        del a
        # this file won't be deleted yet, because it isn't old enough
        cdm.check()
        _expect(present=("a", "b", "c"))

        # we change the definition of "old" to make everything old
        cdm.old = -10
        cdm.check()
        # only "a" is unreferenced, so only "a" goes away
        _expect(present=("b", "c"), absent=("a",))

        cdm.old = 60*60
        del b
        # "b" is unreferenced now, but no longer counts as old
        cdm.check()
        _expect(present=("b", "c"), absent=("a",))

        b2 = cdm.get_file("b")
        cdm.check()
        _expect(present=("b", "c"), absent=("a",))
        del b2
1020
# Global instance counter: each EqButNotIs takes the next value as its hash,
# so two instances wrapping the same value still hash differently.
ctr = [0]
class EqButNotIs:
    """Wraps a value x and compares equal to (and orders like) x without
    being x, for tests that must distinguish "==" from "is".

    NOTE: the hash comes from the global counter, so if a counter value ever
    equals hash(x), a dict would treat this instance and x as the same key.
    """
    def __init__(self, x):
        self.x = x
        # claim the next counter value as this instance's hash
        self.hash, ctr[0] = ctr[0], ctr[0] + 1

    def __hash__(self):
        return self.hash

    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.x,)

    # Every comparison just delegates to the wrapped value.
    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other
1043
class DictUtil(unittest.TestCase):
    """Tests for the dict-like classes and helper functions in dictutil."""

    def _help_test_empty_dict(self, klass):
        """klass() and klass({}) must both be empty and compare equal."""
        d1 = klass()
        d2 = klass({})

        self.failUnless(d1 == d2, "d1: %r, d2: %r" % (d1, d2,))
        self.failUnlessEqual(len(d1), 0)
        self.failUnlessEqual(len(d2), 0)

    def _help_test_nonempty_dict(self, klass):
        """Two klass instances built from equal dicts compare equal and
        report the right length."""
        d1 = klass({'a': 1, 'b': "eggs", 3: "spam",})
        d2 = klass({'a': 1, 'b': "eggs", 3: "spam",})

        self.failUnless(d1 == d2)
        self.failUnless(len(d1) == 3, "%s, %s" % (len(d1), d1,))
        self.failUnless(len(d2) == 3)

    def _help_test_eq_but_notis(self, klass):
        """Exercise klass with keys/values that are equal (==) but not
        identical (is), checking that the right objects are stored,
        replaced and removed.

        The EqButNotIs instances below are created in a fixed order; their
        hashes come from a global counter, so keep that order intact.
        """
        d = klass({'a': 3, 'b': EqButNotIs(3), 'c': 3})
        d.pop('b')

        d.clear()
        d['a'] = 3
        d['b'] = EqButNotIs(3)
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['b'] = EqButNotIs(3)
        d['a'] = 3
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['a'] = EqButNotIs(3)
        d['c'] = 3
        d['a'] = 3

        def _check_identities(fake3, fake7):
            # d[3] = 8 must replace only the real 7 stored under the real
            # key 3, leaving the fake3 -> fake7 entry untouched.
            self.failUnless(any(x is 8 for x in d.itervalues()))
            self.failUnless(any(x is fake7 for x in d.itervalues()))
            # The real 7 should have been ejected by the d[3] = 8.
            self.failIf(any(x is 7 for x in d.itervalues()))
            self.failUnless(any(x is fake3 for x in d.iterkeys()))
            self.failUnless(any(x is 3 for x in d.iterkeys()))

        d.clear()
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 7
        d[3] = 8
        _check_identities(fake3, fake7)
        d[fake3] = 8

        d.clear()
        d[3] = 7
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 8
        _check_identities(fake3, fake7)
        d[fake3] = 8

    def test_all(self):
        # Previously the eq_but_notis helpers were run twice and the
        # empty-dict helper was never called at all.
        klasses = (dictutil.UtilDict, dictutil.NumDict,
                   dictutil.ValueOrderedDict)
        for klass in klasses:
            self._help_test_eq_but_notis(klass)
        for klass in klasses:
            self._help_test_nonempty_dict(klass)
        for klass in klasses:
            self._help_test_empty_dict(klass)

    def test_dict_of_sets(self):
        ds = dictutil.DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        # discarding the last member removes the key itself
        ds.discard(2, "c")
        self.failIf(2 in ds)

        ds.add(3, "f")
        ds2 = dictutil.DictOfSets()
        ds2.add(3, "f")
        ds2.add(3, "g")
        ds2.add(4, "h")
        # update() merges the sets key by key
        ds.update(ds2)
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[3], set(["f", "g"]))
        self.failUnlessEqual(ds[4], set(["h"]))

    def test_move(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(1, d1, d2)
        self.failUnlessEqual(d1, {2: "b"})
        self.failUnlessEqual(d2, {1: "a", 2: "c", 3: "d"})

        # moving onto an existing key overwrites the destination's value
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(2, d1, d2)
        self.failUnlessEqual(d1, {1: "a"})
        self.failUnlessEqual(d2, {2: "b", 3: "d"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        self.failUnlessRaises(KeyError, dictutil.move, 5, d1, d2, strict=True)

    def test_subtract(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

    def test_utildict(self):
        d = dictutil.UtilDict({1: "a", 2: "b"})
        d.del_if_present(1)
        d.del_if_present(3)
        self.failUnlessEqual(d, {2: "b"})
        def eq(a, b):
            return a == b
        # comparing against a non-dict is an error, not just False
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d = dictutil.UtilDict({1: "b", 2: "a"})
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [(2, "a"), (1, "b")])
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(repr(d), "{1: 'b', 2: 'a'}")
        self.failUnless(1 in d)

        d2 = dictutil.UtilDict({3: "c", 4: "d"})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d[1], "b")
        self.failUnlessEqual(sorted(d), [1,2])

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.UtilDict))

        d4 = d.fromkeys([3,4], "e")
        self.failUnlessEqual(d4, {3: "e", 4: "e"})

        self.failUnlessEqual(d.get(1), "b")
        self.failUnlessEqual(d.get(3), None)
        self.failUnlessEqual(d.get(3, "default"), "default")
        self.failUnlessEqual(sorted(d.items()),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.iteritems()),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.keys()), [1, 2])
        self.failUnlessEqual(sorted(d.values()), ["a", "b"])
        x = d.setdefault(1, "new")
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d[1], "b")
        x = d.setdefault(3, "new")
        self.failUnlessEqual(x, "new")
        self.failUnlessEqual(d[3], "new")
        del d[3]

        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        self.failUnlessRaises(KeyError, d.popitem)

    def test_numdict(self):
        d = dictutil.NumDict({"a": 1, "b": 2})

        d.add_num("a", 10, 5)
        d.add_num("c", 20, 5)
        d.add_num("d", 30)
        self.failUnlessEqual(d, {"a": 11, "b": 2, "c": 25, "d": 30})

        d.subtract_num("a", 10)
        d.subtract_num("e", 10)
        d.subtract_num("f", 10, 15)
        self.failUnlessEqual(d, {"a": 1, "b": 2, "c": 25, "d": 30,
                                 "e": -10, "f": 5})

        self.failUnlessEqual(d.sum(), sum([1, 2, 25, 30, -10, 5]))

        d = dictutil.NumDict()
        d.inc("a")
        d.inc("a")
        d.inc("b", 5)
        self.failUnlessEqual(d, {"a": 2, "b": 6})
        d.dec("a")
        d.dec("c")
        d.dec("d", 5)
        self.failUnlessEqual(d, {"a": 1, "b": 6, "c": -1, "d": 4})
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [("a", 1), ("b", 6), ("c", -1), ("d", 4)])
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [("c", -1), ("a", 1), ("d", 4), ("b", 6)])
        self.failUnlessEqual(d.item_with_largest_value(), ("b", 6))

        d = dictutil.NumDict({"a": 1, "b": 2})
        self.failUnlessEqual(repr(d), "{'a': 1, 'b': 2}")
        self.failUnless("a" in d)

        d2 = dictutil.NumDict({"c": 3, "d": 4})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d["a"], 1)
        self.failUnlessEqual(sorted(d), ["a","b"])
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.NumDict))

        d4 = d.fromkeys(["a","b"], 5)
        self.failUnlessEqual(d4, {"a": 5, "b": 5})

        self.failUnlessEqual(d.get("a"), 1)
        # NumDict.get() defaults to 0 rather than None
        self.failUnlessEqual(d.get("c"), 0)
        self.failUnlessEqual(d.get("c", 5), 5)
        self.failUnlessEqual(sorted(d.items()),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.iteritems()),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.keys()), ["a", "b"])
        self.failUnlessEqual(sorted(d.values()), [1, 2])
        self.failUnless(d.has_key("a"))
        self.failIf(d.has_key("c"))

        x = d.setdefault("c", 3)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        x = d.setdefault("c", 5)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        del d["c"]

        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        self.failUnlessRaises(KeyError, d.popitem)

        d.update({"c": 3})
        d.update({"c": 4, "d": 5})
        self.failUnlessEqual(d, {"c": 4, "d": 5})

    def test_del_if_present(self):
        d = {1: "a", 2: "b"}
        dictutil.del_if_present(d, 1)
        dictutil.del_if_present(d, 3)
        self.failUnlessEqual(d, {2: "b"})

    def test_valueordereddict(self):
        d = dictutil.ValueOrderedDict()
        d["a"] = 3
        d["b"] = 2
        d["c"] = 1

        # items/values/keys come back ordered by value, not insertion
        self.failUnlessEqual(d, {"a": 3, "b": 2, "c": 1})
        self.failUnlessEqual(d.items(), [("c", 1), ("b", 2), ("a", 3)])
        self.failUnlessEqual(d.values(), [1, 2, 3])
        self.failUnlessEqual(d.keys(), ["c", "b", "a"])
        self.failUnlessEqual(repr(d), "<ValueOrderedDict {c: 1, b: 2, a: 3}>")
        self.failIf(d == {"a": 4})
        self.failUnless(d != {"a": 4})

        x = d.setdefault("d", 0)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)
        x = d.setdefault("d", -1)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)

        x = d.remove("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.remove, "e", "default", True)
        x = d.remove("d", 5)
        self.failUnlessEqual(x, 0)

        x = d.__getitem__("c")
        self.failUnlessEqual(x, 1)
        x = d.__getitem__("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.__getitem__, "e", "default", True)

        self.failUnlessEqual(d.popitem(), ("c", 1))
        self.failUnlessEqual(d.popitem(), ("b", 2))
        self.failUnlessEqual(d.popitem(), ("a", 3))
        self.failUnlessRaises(KeyError, d.popitem)

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop("d", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.pop, "d", "default", True)
        x = d.pop("b")
        self.failUnlessEqual(x, 2)
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop_from_list(1) # pop the second item, b/2
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

    def test_auxdict(self):
        d = dictutil.AuxValueDict()
        # we put the serialized form in the auxdata
        d.set_with_aux("key", ("filecap", "metadata"), "serialized")

        self.failUnlessEqual(d.keys(), ["key"])
        self.failUnlessEqual(d["key"], ("filecap", "metadata"))
        self.failUnlessEqual(d.get_aux("key"), "serialized")
        def _get_missing(key):
            return d[key]
        self.failUnlessRaises(KeyError, _get_missing, "nonkey")
        self.failUnlessEqual(d.get("nonkey"), None)
        self.failUnlessEqual(d.get("nonkey", "nonvalue"), "nonvalue")
        self.failUnlessEqual(d.get_aux("nonkey"), None)
        self.failUnlessEqual(d.get_aux("nonkey", "nonvalue"), "nonvalue")

        # plain assignment replaces the value and clears the aux data
        d["key"] = ("filecap2", "metadata2")
        self.failUnlessEqual(d["key"], ("filecap2", "metadata2"))
        self.failUnlessEqual(d.get_aux("key"), None)

        d.set_with_aux("key2", "value2", "aux2")
        self.failUnlessEqual(sorted(d.keys()), ["key", "key2"])
        del d["key2"]
        self.failUnlessEqual(d.keys(), ["key"])
        self.failIf("key2" in d)
        self.failUnlessRaises(KeyError, _get_missing, "key2")
        self.failUnlessEqual(d.get("key2"), None)
        self.failUnlessEqual(d.get_aux("key2"), None)
        d["key2"] = "newvalue2"
        self.failUnlessEqual(d.get("key2"), "newvalue2")
        self.failUnlessEqual(d.get_aux("key2"), None)

        d = dictutil.AuxValueDict({1:2,3:4})
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict([ (1,2), (3,4) ])
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict(one=1, two=2)
        self.failUnlessEqual(sorted(d.keys()), ["one","two"])
        self.failUnlessEqual(d["one"], 1)
        self.failUnlessEqual(d.get_aux("one"), None)
1418
1419 class Pipeline(unittest.TestCase):
1420     def pause(self, *args, **kwargs):
1421         d = defer.Deferred()
1422         self.calls.append( (d, args, kwargs) )
1423         return d
1424
1425     def failUnlessCallsAre(self, expected):
1426         #print self.calls
1427         #print expected
1428         self.failUnlessEqual(len(self.calls), len(expected), self.calls)
1429         for i,c in enumerate(self.calls):
1430             self.failUnlessEqual(c[1:], expected[i], str(i))
1431
    def test_basic(self):
        """Walk a Pipeline(100) through add()/flush() and check when each
        returned Deferred fires as the gauge fills past capacity and the
        pending calls are retired."""
        self.calls = []
        finished = []
        p = pipeline.Pipeline(100)

        d = p.flush() # fires immediately
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        finished = []

        d = p.add(10, self.pause, "one")
        # the call should start right away, and our return Deferred should
        # fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ) ])
        # the gauge tracks the summed sizes of outstanding calls
        self.failUnlessEqual(p.gauge, 10)

        # pipeline: [one]

        finished = []
        d = p.add(20, self.pause, "two", kw=2)
        # pipeline: [one, two]

        # the call and the Deferred should fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ])
        self.failUnlessEqual(p.gauge, 30)

        # retiring call "one" removes its size from the gauge
        self.calls[0][0].callback("one-result")
        # pipeline: [two]
        self.failUnlessEqual(p.gauge, 20)

        finished = []
        d = p.add(90, self.pause, "three", "posarg1")
        # pipeline: [two, three]
        # a flush() issued now must wait for every outstanding call
        flushed = []
        fd = p.flush()
        fd.addCallbacks(flushed.append, log.err)
        self.failUnlessEqual(flushed, [])

        # the call will be made right away, but the return Deferred will not,
        # because the pipeline is now full.
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 0)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ( ("three", "posarg1"), {} ),
                                  ])
        # 110 exceeds the Pipeline(100) capacity
        self.failUnlessEqual(p.gauge, 110)

        # while over capacity, further add() calls are refused
        self.failUnlessRaises(pipeline.SingleFileError, p.add, 10, self.pause)

        # retiring either call will unblock the pipeline, causing the #3
        # Deferred to fire
        self.calls[2][0].callback("three-result")
        # pipeline: [two]

        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessEqual(flushed, [])

        # retiring call#2 will finally allow the flush() Deferred to fire
        self.calls[1][0].callback("two-result")
        self.failUnlessEqual(len(flushed), 1)
1502
1503     def test_errors(self):
1504         self.calls = []
1505         p = pipeline.Pipeline(100)
1506
1507         d1 = p.add(200, self.pause, "one")
1508         d2 = p.flush()
1509
1510         finished = []
1511         d1.addBoth(finished.append)
1512         self.failUnlessEqual(finished, [])
1513
1514         flushed = []
1515         d2.addBoth(flushed.append)
1516         self.failUnlessEqual(flushed, [])
1517
1518         self.calls[0][0].errback(ValueError("oops"))
1519
1520         self.failUnlessEqual(len(finished), 1)
1521         f = finished[0]
1522         self.failUnless(isinstance(f, Failure))
1523         self.failUnless(f.check(pipeline.PipelineError))
1524         self.failUnlessIn("PipelineError", str(f.value))
1525         self.failUnlessIn("ValueError", str(f.value))
1526         r = repr(f.value)
1527         self.failUnless("ValueError" in r, r)
1528         f2 = f.value.error
1529         self.failUnless(f2.check(ValueError))
1530
1531         self.failUnlessEqual(len(flushed), 1)
1532         f = flushed[0]
1533         self.failUnless(isinstance(f, Failure))
1534         self.failUnless(f.check(pipeline.PipelineError))
1535         f2 = f.value.error
1536         self.failUnless(f2.check(ValueError))
1537
1538         # now that the pipeline is in the failed state, any new calls will
1539         # fail immediately
1540
1541         d3 = p.add(20, self.pause, "two")
1542
1543         finished = []
1544         d3.addBoth(finished.append)
1545         self.failUnlessEqual(len(finished), 1)
1546         f = finished[0]
1547         self.failUnless(isinstance(f, Failure))
1548         self.failUnless(f.check(pipeline.PipelineError))
1549         r = repr(f.value)
1550         self.failUnless("ValueError" in r, r)
1551         f2 = f.value.error
1552         self.failUnless(f2.check(ValueError))
1553
1554         d4 = p.flush()
1555         flushed = []
1556         d4.addBoth(flushed.append)
1557         self.failUnlessEqual(len(flushed), 1)
1558         f = flushed[0]
1559         self.failUnless(isinstance(f, Failure))
1560         self.failUnless(f.check(pipeline.PipelineError))
1561         f2 = f.value.error
1562         self.failUnless(f2.check(ValueError))
1563
    def test_errors2(self):
        """flush() fails as soon as one pipelined call errbacks; the
        remaining calls' later results (one success, one failure) must then
        be absorbed without breaking the test."""
        self.calls = []
        p = pipeline.Pipeline(100)

        # three outstanding calls plus a flush() waiting on them
        d1 = p.add(10, self.pause, "one")
        d2 = p.add(20, self.pause, "two")
        d3 = p.add(30, self.pause, "three")
        d4 = p.flush()

        # one call fails, then the second one succeeds: make sure
        # ExpandableDeferredList tolerates the second one

        flushed = []
        d4.addBoth(flushed.append)
        self.failUnlessEqual(flushed, [])

        # the first failure makes the flush() fail right away, with a
        # PipelineError wrapping the original ValueError
        self.calls[0][0].errback(ValueError("oops"))
        self.failUnlessEqual(len(flushed), 1)
        f = flushed[0]
        self.failUnless(isinstance(f, Failure))
        self.failUnless(f.check(pipeline.PipelineError))
        f2 = f.value.error
        self.failUnless(f2.check(ValueError))

        # results arriving after the pipeline has already failed
        self.calls[1][0].callback("two-result")
        self.calls[2][0].errback(ValueError("three-error"))

        # NOTE(review): d1-d3 are never inspected; this del presumably just
        # drops the references / silences unused-name warnings -- confirm.
        del d1,d2,d3,d4
1592
class SampleError(Exception):
    """Exception raised on purpose by Log.test_err, so that the error it
    causes to be logged can be recognized and flushed."""
1595
class Log(unittest.TestCase):
    def test_err(self):
        # tahoe_log.err() must accept a Failure together with format=,
        # level= and umid= keywords. The error it records in the Twisted log
        # is flushed afterwards, so the test does not end up [ERROR]ing.
        if not hasattr(self, "flushLoggedErrors"):
            # without flushLoggedErrors, we can't get rid of the
            # twisted.log.err that tahoe_log records, so we can't keep this
            # test from [ERROR]ing
            raise unittest.SkipTest("needs flushLoggedErrors from Twisted-2.5.0")
        try:
            raise SampleError("simple sample")
        except SampleError:
            # Catch only the exception raised just above. A bare except
            # would also swallow KeyboardInterrupt/SystemExit. Failure()
            # with no arguments wraps the exception currently being handled.
            f = Failure()
        tahoe_log.err(format="intentional sample error",
                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
        self.flushLoggedErrors(SampleError)
1610
1611
class SimpleSpans:
    # this is a simple+inefficient form of util.spans.Spans . We compare the
    # behavior of this reference model against the real (efficient) form.
    # Every covered offset is stored individually in the set self._have, so
    # it is only suitable for small ranges.

    def __init__(self, _span_or_start=None, length=None):
        # SimpleSpans() is empty; SimpleSpans(start, length) covers
        # start..start+length-1; SimpleSpans(spans) copies any iterable of
        # (start, length) pairs, such as another SimpleSpans.
        self._have = set()
        if length is not None:
            self.add(_span_or_start, length)
        elif _span_or_start:
            for (span_start, span_length) in _span_or_start:
                self.add(span_start, span_length)

    def add(self, start, length):
        # cover start..start+length-1; returns self so calls can be chained
        self._have.update(range(start, start+length))
        return self

    def remove(self, start, length):
        # uncover start..start+length-1 (offsets not covered are ignored);
        # returns self so calls can be chained
        self._have.difference_update(range(start, start+length))
        return self

    def each(self):
        # sorted list of every covered offset
        return sorted(self._have)

    def __iter__(self):
        # yield (start, length) for each maximal run of consecutive covered
        # offsets, in ascending order
        prevstart = None
        prevend = None
        for i in sorted(self._have):
            if prevstart is None:
                prevstart = prevend = i
            elif i == prevend+1:
                prevend = i
            else:
                yield (prevstart, prevend-prevstart+1)
                prevstart = prevend = i
        if prevstart is not None:
            yield (prevstart, prevend-prevstart+1)

    def __nonzero__(self): # this gets us bool()
        return bool(self._have)
    # Python 3 ignores __nonzero__; without this alias every instance would
    # be truthy there (which also breaks copying an empty SimpleSpans)
    __bool__ = __nonzero__

    def len(self):
        # number of covered offsets
        return len(self._have)

    def __add__(self, other):
        # union with any iterable of (start, length) pairs, as a new object
        s = self.__class__(self)
        for (start, length) in other:
            s.add(start, length)
        return s

    def __sub__(self, other):
        # difference with any iterable of (start, length) pairs, as a new
        # object
        s = self.__class__(self)
        for (start, length) in other:
            s.remove(start, length)
        return s

    def __iadd__(self, other):
        for (start, length) in other:
            self.add(start, length)
        return self

    def __isub__(self, other):
        for (start, length) in other:
            self.remove(start, length)
        return self

    def __and__(self, other):
        # intersection, as a new object; other must provide each()
        s = self.__class__()
        for i in other.each():
            if i in self._have:
                s.add(i, 1)
        return s

    def __contains__(self, span):
        # True if every offset of the (start, length) span is covered; this
        # is NOT the same as testing membership in list(self). The tuple is
        # unpacked here because tuple parameters were removed in Python 3.
        (start, length) = span
        for i in range(start, start+length):
            if i not in self._have:
                return False
        return True
1694
1695 class ByteSpans(unittest.TestCase):
1696     def test_basic(self):
1697         s = Spans()
1698         self.failUnlessEqual(list(s), [])
1699         self.failIf(s)
1700         self.failIf((0,1) in s)
1701         self.failUnlessEqual(s.len(), 0)
1702
1703         s1 = Spans(3, 4) # 3,4,5,6
1704         self._check1(s1)
1705
1706         s1 = Spans(3L, 4L) # 3,4,5,6
1707         self._check1(s1)
1708
1709         s2 = Spans(s1)
1710         self._check1(s2)
1711
1712         s2.add(10,2) # 10,11
1713         self._check1(s1)
1714         self.failUnless((10,1) in s2)
1715         self.failIf((10,1) in s1)
1716         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
1717         self.failUnlessEqual(s2.len(), 6)
1718
1719         s2.add(15,2).add(20,2)
1720         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
1721         self.failUnlessEqual(s2.len(), 10)
1722
1723         s2.remove(4,3).remove(15,1)
1724         self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
1725         self.failUnlessEqual(s2.len(), 6)
1726
1727         s1 = SimpleSpans(3, 4) # 3 4 5 6
1728         s2 = SimpleSpans(5, 4) # 5 6 7 8
1729         i = s1 & s2
1730         self.failUnlessEqual(list(i.each()), [5, 6])
1731
1732     def _check1(self, s):
1733         self.failUnlessEqual(list(s), [(3,4)])
1734         self.failUnless(s)
1735         self.failUnlessEqual(s.len(), 4)
1736         self.failIf((0,1) in s)
1737         self.failUnless((3,4) in s)
1738         self.failUnless((3,1) in s)
1739         self.failUnless((5,2) in s)
1740         self.failUnless((6,1) in s)
1741         self.failIf((6,2) in s)
1742         self.failIf((7,1) in s)
1743         self.failUnlessEqual(list(s.each()), [3,4,5,6])
1744
1745     def test_large(self):
1746         s = Spans(4, 2**65) # don't do this with a SimpleSpans
1747         self.failUnlessEqual(list(s), [(4, 2**65)])
1748         self.failUnless(s)
1749         self.failUnlessEqual(s.len(), 2**65)
1750         self.failIf((0,1) in s)
1751         self.failUnless((4,2) in s)
1752         self.failUnless((2**65,2) in s)
1753
1754     def test_math(self):
1755         s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
1756         s2 = Spans(5, 3) # 5,6,7
1757         s3 = Spans(8, 4) # 8,9,10,11
1758
1759         s = s1 - s2
1760         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1761         s = s1 - s3
1762         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1763         s = s2 - s3
1764         self.failUnlessEqual(list(s.each()), [5,6,7])
1765         s = s1 & s2
1766         self.failUnlessEqual(list(s.each()), [5,6,7])
1767         s = s2 & s1
1768         self.failUnlessEqual(list(s.each()), [5,6,7])
1769         s = s1 & s3
1770         self.failUnlessEqual(list(s.each()), [8,9])
1771         s = s3 & s1
1772         self.failUnlessEqual(list(s.each()), [8,9])
1773         s = s2 & s3
1774         self.failUnlessEqual(list(s.each()), [])
1775         s = s3 & s2
1776         self.failUnlessEqual(list(s.each()), [])
1777         s = Spans() & s3
1778         self.failUnlessEqual(list(s.each()), [])
1779         s = s3 & Spans()
1780         self.failUnlessEqual(list(s.each()), [])
1781
1782         s = s1 + s2
1783         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1784         s = s1 + s3
1785         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1786         s = s2 + s3
1787         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1788
1789         s = Spans(s1)
1790         s -= s2
1791         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1792         s = Spans(s1)
1793         s -= s3
1794         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1795         s = Spans(s2)
1796         s -= s3
1797         self.failUnlessEqual(list(s.each()), [5,6,7])
1798
1799         s = Spans(s1)
1800         s += s2
1801         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1802         s = Spans(s1)
1803         s += s3
1804         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1805         s = Spans(s2)
1806         s += s3
1807         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1808
1809     def test_random(self):
1810         # attempt to increase coverage of corner cases by comparing behavior
1811         # of a simple-but-slow model implementation against the
1812         # complex-but-fast actual implementation, in a large number of random
1813         # operations
1814         S1 = SimpleSpans
1815         S2 = Spans
1816         s1 = S1(); s2 = S2()
1817         seed = ""
1818         def _create(subseed):
1819             ns1 = S1(); ns2 = S2()
1820             for i in range(10):
1821                 what = _hash(subseed+str(i)).hexdigest()
1822                 start = int(what[2:4], 16)
1823                 length = max(1,int(what[5:6], 16))
1824                 ns1.add(start, length); ns2.add(start, length)
1825             return ns1, ns2
1826
1827         #print
1828         for i in range(1000):
1829             what = _hash(seed+str(i)).hexdigest()
1830             op = what[0]
1831             subop = what[1]
1832             start = int(what[2:4], 16)
1833             length = max(1,int(what[5:6], 16))
1834             #print what
1835             if op in "0":
1836                 if subop in "01234":
1837                     s1 = S1(); s2 = S2()
1838                 elif subop in "5678":
1839                     s1 = S1(start, length); s2 = S2(start, length)
1840                 else:
1841                     s1 = S1(s1); s2 = S2(s2)
1842                 #print "s2 = %s" % s2.dump()
1843             elif op in "123":
1844                 #print "s2.add(%d,%d)" % (start, length)
1845                 s1.add(start, length); s2.add(start, length)
1846             elif op in "456":
1847                 #print "s2.remove(%d,%d)" % (start, length)
1848                 s1.remove(start, length); s2.remove(start, length)
1849             elif op in "78":
1850                 ns1, ns2 = _create(what[7:11])
1851                 #print "s2 + %s" % ns2.dump()
1852                 s1 = s1 + ns1; s2 = s2 + ns2
1853             elif op in "9a":
1854                 ns1, ns2 = _create(what[7:11])
1855                 #print "%s - %s" % (s2.dump(), ns2.dump())
1856                 s1 = s1 - ns1; s2 = s2 - ns2
1857             elif op in "bc":
1858                 ns1, ns2 = _create(what[7:11])
1859                 #print "s2 += %s" % ns2.dump()
1860                 s1 += ns1; s2 += ns2
1861             elif op in "de":
1862                 ns1, ns2 = _create(what[7:11])
1863                 #print "%s -= %s" % (s2.dump(), ns2.dump())
1864                 s1 -= ns1; s2 -= ns2
1865             else:
1866                 ns1, ns2 = _create(what[7:11])
1867                 #print "%s &= %s" % (s2.dump(), ns2.dump())
1868                 s1 = s1 & ns1; s2 = s2 & ns2
1869             #print "s2 now %s" % s2.dump()
1870             self.failUnlessEqual(list(s1.each()), list(s2.each()))
1871             self.failUnlessEqual(s1.len(), s2.len())
1872             self.failUnlessEqual(bool(s1), bool(s2))
1873             self.failUnlessEqual(list(s1), list(s2))
1874             for j in range(10):
1875                 what = _hash(what[12:14]+str(j)).hexdigest()
1876                 start = int(what[2:4], 16)
1877                 length = max(1, int(what[5:6], 16))
1878                 span = (start, length)
1879                 self.failUnlessEqual(bool(span in s1), bool(span in s2))
1880
1881
1882     # s()
1883     # s(start,length)
1884     # s(s0)
1885     # s.add(start,length) : returns s
1886     # s.remove(start,length)
1887     # s.each() -> list of byte offsets, mostly for testing
1888     # list(s) -> list of (start,length) tuples, one per span
1889     # (start,length) in s -> True if (start..start+length-1) are all members
1890     #  NOT equivalent to x in list(s)
1891     # s.len() -> number of bytes, for testing, bool(), and accounting/limiting
1892     # bool(s)  (__nonzeron__)
1893     # s = s1+s2, s1-s2, +=s1, -=s1
1894
1895     def test_overlap(self):
1896         for a in range(20):
1897             for b in range(10):
1898                 for c in range(20):
1899                     for d in range(10):
1900                         self._test_overlap(a,b,c,d)
1901
1902     def _test_overlap(self, a, b, c, d):
1903         s1 = set(range(a,a+b))
1904         s2 = set(range(c,c+d))
1905         #print "---"
1906         #self._show_overlap(s1, "1")
1907         #self._show_overlap(s2, "2")
1908         o = overlap(a,b,c,d)
1909         expected = s1.intersection(s2)
1910         if not expected:
1911             self.failUnlessEqual(o, None)
1912         else:
1913             start,length = o
1914             so = set(range(start,start+length))
1915             #self._show(so, "o")
1916             self.failUnlessEqual(so, expected)
1917
1918     def _show_overlap(self, s, c):
1919         import sys
1920         out = sys.stdout
1921         if s:
1922             for i in range(max(s)):
1923                 if i in s:
1924                     out.write(c)
1925                 else:
1926                     out.write(" ")
1927         out.write("\n")
1928
def extend(s, start, length, fill):
    """Return s padded with the single character fill until it is at least
    start+length long; s is returned unchanged if already long enough."""
    shortfall = start + length - len(s)
    if shortfall <= 0:
        return s
    assert len(fill) == 1
    return s + fill * shortfall
1934
def replace(s, start, data):
    """Return s with the characters at start..start+len(data)-1 overwritten
    by data; s must already be long enough to hold them."""
    end = start + len(data)
    assert len(s) >= end
    return "".join((s[:start], data, s[end:]))
1938
class SimpleDataSpans:
    # Simple+inefficient reference model of util.spans.DataSpans, used by
    # StringSpans to cross-check the real implementation. State is two
    # strings indexed by offset: self.missing holds "1" where an offset is
    # absent and "0" where it is present; self.data holds the byte last
    # written at each offset (only meaningful where missing is "0").

    def __init__(self, other=None):
        # optionally copy every present byte of another DataSpans-like
        # object (anything providing get_chunks())
        self.missing = "" # "1" where missing, "0" where found
        self.data = ""
        if other:
            for (start, data) in other.get_chunks():
                self.add(start, data)

    def __nonzero__(self): # this gets us bool()
        return self.len() > 0
    # Python 3 ignores __nonzero__; without this alias every instance would
    # be truthy there
    __bool__ = __nonzero__

    def len(self):
        # number of offsets currently present
        return self.missing.count("0")

    def _dump(self):
        # sorted list of the offsets currently present
        return [i for (i,c) in enumerate(self.missing) if c == "0"]

    def _have(self, start, length):
        # True only if all of start..start+length-1 are present; an empty
        # range, or one running past the end of self.missing, is not "had"
        m = self.missing[start:start+length]
        if not m or len(m) < length:
            return False
        return "1" not in m

    def get_chunks(self):
        # yield (offset, one-character string) for every present offset
        for i in self._dump():
            yield (i, self.data[i])

    def get_spans(self):
        # the present offsets, as a SimpleSpans
        return SimpleSpans([(start,len(data))
                            for (start,data) in self.get_chunks()])

    def get(self, start, length):
        # the bytes at start..start+length-1, or None unless all are present
        if self._have(start, length):
            return self.data[start:start+length]
        return None

    def pop(self, start, length):
        # like get(), but also removes the range when it was present
        data = self.get(start, length)
        if data:
            self.remove(start, length)
        return data

    def remove(self, start, length):
        # mark start..start+length-1 as missing; the old bytes stay in
        # self.data but are no longer reported
        self.missing = replace(extend(self.missing, start, length, "1"),
                               start, "1"*length)

    def add(self, start, data):
        # store data at start, overwriting any bytes already there
        self.missing = replace(extend(self.missing, start, len(data), "1"),
                               start, "0"*len(data))
        self.data = replace(extend(self.data, start, len(data), " "),
                            start, data)
1981
1982
1983 class StringSpans(unittest.TestCase):
    def do_basic(self, klass):
        """Run a fixed add/get/pop/remove scenario against klass.

        klass is SimpleDataSpans (from test_test) or DataSpans (from
        test_basic); the same assertions are applied to both.
        """
        # operations on an empty instance: nothing present, and get/pop/
        # remove on absent data are harmless
        ds = klass()
        self.failUnlessEqual(ds.len(), 0)
        self.failUnlessEqual(list(ds._dump()), [])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
        s1 = ds.get_spans()
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)
        ds.remove(0, 4)

        # one run of 4 bytes at offsets 2..5; any range not wholly inside
        # it (0..3, 4..7) yields None
        ds.add(2, "four")
        self.failUnlessEqual(ds.len(), 4)
        self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
        s1 = ds.get_spans()
        self.failUnless((2,2) in s1)
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)
        self.failUnlessEqual(ds.get(4, 4), None)

        # copy construction: popping from the copy must not affect ds
        ds2 = klass(ds)
        self.failUnlessEqual(ds2.len(), 4)
        self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
        self.failUnlessEqual(ds2.get(0, 4), None)
        self.failUnlessEqual(ds2.pop(0, 4), None)
        self.failUnlessEqual(ds2.pop(2, 3), "fou")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
        self.failUnlessEqual(ds2.get(2, 3), None)
        self.failUnlessEqual(ds2.get(5, 1), "r")
        self.failUnlessEqual(ds.get(2, 3), "fou")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)

        # adding an adjacent run joins it with the existing data
        ds.add(0, "23")
        self.failUnlessEqual(ds.len(), 6)
        self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
        self.failUnlessEqual(ds.get(0, 4), "23fo")
        self.failUnlessEqual(ds.pop(0, 4), "23fo")
        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
        self.failUnlessEqual(ds.get(0, 4), None)
        self.failUnlessEqual(ds.pop(0, 4), None)

        # an overlapping add() overwrites the bytes already stored there
        ds = klass()
        ds.add(2, "four")
        ds.add(3, "ea")
        self.failUnlessEqual(ds.get(2, 4), "fear")

        # same again with long offsets, which must behave like ints
        ds = klass()
        ds.add(2L, "four")
        ds.add(3L, "ea")
        self.failUnlessEqual(ds.get(2L, 4L), "fear")
2036
2037
    def do_scan(self, klass):
        """Exhaustively add each substring S[start:end] on top of a fixed
        layout of pieces and gaps, then check get() and remove() results
        against a set-based expectation. klass is the DataSpans-like class
        under test."""
        # do a test with gaps and spans of size 1 and 2
        #  left=(1,11) * right=(1,11) * gapsize=(1,2)
        # 111, 112, 121, 122, 211, 212, 221, 222
        #    211
        #      121
        #         112
        #            212
        #               222
        #                   221
        #                      111
        #                        122
        #  11 1  1 11 11  11  1 1  111
        # 0123456789012345678901234567
        # abcdefghijklmnopqrstuvwxyz-=
        pieces = [(1, "bc"),
                  (4, "e"),
                  (7, "h"),
                  (9, "jk"),
                  (12, "mn"),
                  (16, "qr"),
                  (20, "u"),
                  (22, "w"),
                  (25, "z-="),
                  ]
        # the offsets covered by pieces, precomputed for the checks below
        p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
        S = "abcdefghijklmnopqrstuvwxyz-="
        # TODO: when adding data, add capital letters, to make sure we aren't
        # just leaving the old data in place
        l = len(S)
        def base():
            # a fresh instance holding exactly the pieces
            ds = klass()
            for start, data in pieces:
                ds.add(start, data)
            return ds
        def dump(s):
            # render s as a len(S)-wide string: S[i] where present, " " not
            d = "".join([((i not in p) and " " or S[i]) for i in range(l)])
            assert len(d) == l
            return d
        # set DEBUG to True to print a picture of each step via dump()
        DEBUG = False
        # NOTE(review): end stops at l-1, so no added range reaches the last
        # offset (l-1); range(start+1, l+1) would cover it -- confirm
        # whether that is intentional.
        for start in range(0, l):
            for end in range(start+1, l):
                # add [start-end) to the baseline
                which = "%d-%d" % (start, end-1)
                p_added = set(range(start, end))
                b = base()
                if DEBUG:
                    print
                    print dump(b), which
                    add = klass(); add.add(start, S[start:end])
                    print dump(add)
                b.add(start, S[start:end])
                if DEBUG:
                    print dump(b)
                # check that the new span is there
                d = b.get(start, end-start)
                self.failUnlessEqual(d, S[start:end], which)
                # check that all the original pieces are still there
                for t_start, t_data in pieces:
                    t_len = len(t_data)
                    self.failUnlessEqual(b.get(t_start, t_len),
                                         S[t_start:t_start+t_len],
                                         "%s %d+%d" % (which, t_start, t_len))
                # check that a lot of subspans are mostly correct
                for t_start in range(l):
                    for t_len in range(1,4):
                        d = b.get(t_start, t_len)
                        if d is not None:
                            which2 = "%s+(%d-%d)" % (which, t_start,
                                                     t_start+t_len-1)
                            self.failUnlessEqual(d, S[t_start:t_start+t_len],
                                                 which2)
                        # check that removing a subspan gives the right value
                        b2 = klass(b)
                        b2.remove(t_start, t_len)
                        removed = set(range(t_start, t_start+t_len))
                        for i in range(l):
                            # offset i must be present iff it was in a piece
                            # or in the added range, and was not just removed
                            exp = (((i in p_elements) or (i in p_added))
                                   and (i not in removed))
                            which2 = "%s-(%d-%d)" % (which, t_start,
                                                     t_start+t_len-1)
                            self.failUnlessEqual(bool(b2.get(i, 1)), exp,
                                                 which2+" %d" % i)
2122
2123     def test_test(self):
2124         self.do_basic(SimpleDataSpans)
2125         self.do_scan(SimpleDataSpans)
2126
2127     def test_basic(self):
2128         self.do_basic(DataSpans)
2129         self.do_scan(DataSpans)
2130
    def test_random(self):
        # attempt to increase coverage of corner cases by comparing behavior
        # of a simple-but-slow model implementation against the
        # complex-but-fast actual implementation, in a large number of random
        # operations
        # (the "random" choices come from SHA-256 hex digests, so every run
        # performs the same operations)
        S1 = SimpleDataSpans
        S2 = DataSpans
        s1 = S1(); s2 = S2()
        seed = ""
        def _randstr(length, seed):
            # deterministic string of exactly `length` hex characters
            # derived from seed
            created = 0
            pieces = []
            while created < length:
                piece = _hash(seed + str(created)).hexdigest()
                pieces.append(piece)
                created += len(piece)
            return "".join(pieces)[:length]
        def _create(subseed):
            # build an identical pair of instances, each holding 10 writes
            # derived from subseed
            ns1 = S1(); ns2 = S2()
            for i in range(10):
                what = _hash(subseed+str(i)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1,int(what[5:6], 16))
                ns1.add(start, _randstr(length, what[7:9]));
                ns2.add(start, _randstr(length, what[7:9]))
            return ns1, ns2

        #print
        for i in range(1000):
            # the first hex digit of the digest picks the operation
            what = _hash(seed+str(i)).hexdigest()
            op = what[0]
            subop = what[1]
            start = int(what[2:4], 16)
            length = max(1,int(what[5:6], 16))
            #print what
            if op in "0":
                if subop in "0123456":
                    s1 = S1(); s2 = S2()
                else:
                    s1, s2 = _create(what[7:11])
                #print "s2 = %s" % list(s2._dump())
            elif op in "123456":
                #print "s2.add(%d,%d)" % (start, length)
                s1.add(start, _randstr(length, what[7:9]));
                s2.add(start, _randstr(length, what[7:9]))
            elif op in "789abc":
                #print "s2.remove(%d,%d)" % (start, length)
                s1.remove(start, length); s2.remove(start, length)
            else:
                #print "s2.pop(%d,%d)" % (start, length)
                d1 = s1.pop(start, length); d2 = s2.pop(start, length)
                self.failUnlessEqual(d1, d2)
            #print "s1 now %s" % list(s1._dump())
            #print "s2 now %s" % list(s2._dump())
            # after every operation the model and the real implementation
            # must agree on size, present offsets, and 100 derived get()s
            self.failUnlessEqual(s1.len(), s2.len())
            self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
            for j in range(100):
                what = _hash(what[12:14]+str(j)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1, int(what[5:6], 16))
                d1 = s1.get(start, length); d2 = s2.get(start, length)
                self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))