]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
fileutil: copy in the get_disk_stats() and get_available_space() functions from stora...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): pass # keep the line number constant: HumanReadable.test_repr expects hr(foo) to report "test_util.py:2"
3
4 import os, time, sys
5 from StringIO import StringIO
6 from twisted.trial import unittest
7 from twisted.internet import defer, reactor
8 from twisted.python.failure import Failure
9 from twisted.python import log
10 from pycryptopp.hash.sha256 import SHA256 as _hash
11
12 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
13 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
14 from allmydata.util import limiter, time_format, pollmixin, cachedir
15 from allmydata.util import statistics, dictutil, pipeline
16 from allmydata.util import log as tahoe_log
17 from allmydata.util.spans import Spans, overlap, DataSpans
18
class Base32(unittest.TestCase):
    """Exercise the encode/decode helpers of allmydata.util.base32."""

    def test_b2a_matches_Pythons(self):
        # base32.b2a should agree with the stdlib encoder once the trailing
        # '=' padding is dropped and the result is lower-cased.
        import base64
        raw = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        expected = base64.b32encode(raw).rstrip('=').lower()
        self.failUnlessEqual(base32.b2a(raw), expected)

    def test_b2a(self):
        encoded = base32.b2a("\x12\x34")
        self.failUnlessEqual(encoded, "ci2a")

    def test_b2a_or_none(self):
        # None passes straight through; anything else is encoded
        for (given, expected) in [(None, None), ("\x12\x34", "ci2a")]:
            self.failUnlessEqual(base32.b2a_or_none(given), expected)

    def test_a2b(self):
        decoded = base32.a2b("ci2a")
        self.failUnlessEqual(decoded, "\x12\x34")
        # "b0gus" is expected to be rejected with an AssertionError
        self.failUnlessRaises(AssertionError, base32.a2b, "b0gus")
36
class IDLib(unittest.TestCase):
    """Tests for allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        # twenty zero bytes are expected to render as thirty-two "a"s
        all_zero_nodeid = "\x00" * 20
        rendered = idlib.nodeid_b2a(all_zero_nodeid)
        self.failUnlessEqual(rendered, "a" * 32)
40
class NoArgumentException(Exception):
    """Exception whose constructor accepts no arguments at all."""

    def __init__(self):
        """Deliberately do nothing (Exception.__init__ is not invoked)."""
44
45 class HumanReadable(unittest.TestCase):
46     def test_repr(self):
47         hr = humanreadable.hr
48         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
49         self.failUnlessEqual(hr(self.test_repr),
50                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
51         self.failUnlessEqual(hr(1L), "1")
52         self.failUnlessEqual(hr(10**40),
53                              "100000000000000000...000000000000000000")
54         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
55         self.failUnlessEqual(hr([1,2]), "[1, 2]")
56         self.failUnlessEqual(hr({1:2}), "{1:2}")
57         try:
58             raise ValueError
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<ValueError: ()>" # python-2.4
62                 or hr(e) == "ValueError()") # python-2.5
63         try:
64             raise ValueError("oops")
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<ValueError: 'oops'>" # python-2.4
68                 or hr(e) == "ValueError('oops',)") # python-2.5
69         try:
70             raise NoArgumentException
71         except Exception, e:
72             self.failUnless(
73                 hr(e) == "<NoArgumentException>" # python-2.4
74                 or hr(e) == "NoArgumentException()") # python-2.5
75
76
class MyList(list):
    """A list subclass that adds nothing of its own."""
79
class Math(unittest.TestCase):
    """Table-driven checks of the arithmetic helpers in mathutil."""

    def test_div_ceil(self):
        # each entry is ((n, d), expected result of div_ceil(n, d))
        table = [((0, 1), 0), ((0, 2), 0), ((0, 3), 0),
                 ((1, 3), 1), ((2, 3), 1), ((3, 3), 1),
                 ((4, 3), 2), ((5, 3), 2), ((6, 3), 2),
                 ((7, 3), 3)]
        for ((n, d), expected) in table:
            self.failUnlessEqual(mathutil.div_ceil(n, d), expected)

    def test_next_multiple(self):
        next_multiple = mathutil.next_multiple

        # expected next_multiple(5, k) for k = 1, 2, ... 6
        for_five = [5, 6, 6, 8, 5, 6]
        for (index, expected) in enumerate(for_five):
            self.failUnlessEqual(next_multiple(5, index + 1), expected)

        # expected next_multiple(32, k) for k = 1, 2, ... 18
        for_thirty_two = [32, 32, 33, 32, 35, 36,
                          35, 32, 36, 40, 33, 36,
                          39, 42, 45, 32, 34, 36]
        for (index, expected) in enumerate(for_thirty_two):
            self.failUnlessEqual(next_multiple(32, index + 1), expected)

        self.failUnlessEqual(next_multiple(32, 589), 589)

    def test_pad_size(self):
        # expected pad_size(n, 4) for n = 0, 1, ... 5
        paddings = [0, 3, 2, 1, 0, 3]
        for (n, expected) in enumerate(paddings):
            self.failUnlessEqual(mathutil.pad_size(n, 4), expected)

    def test_is_power_of_k(self):
        is_power = mathutil.is_power_of_k
        # for each base k, the exact powers of k that lie below 100
        powers_below_100 = [(2, (1, 2, 4, 8, 16, 32, 64)),
                            (3, (1, 3, 9, 27, 81))]
        for (k, powers) in powers_below_100:
            for i in range(1, 100):
                if i in powers:
                    self.failUnless(is_power(i, k),
                                    "but %d *is* a power of %d" % (i, k))
                else:
                    self.failIf(is_power(i, k),
                                "but %d is *not* a power of %d" % (i, k))

    def test_next_power_of_k(self):
        next_power = mathutil.next_power_of_k

        # base 2: individual values first, then whole ranges of inputs that
        # share one answer (those carry the input as the failure message)
        for (n, expected) in [(0, 1), (1, 1), (2, 2), (3, 4), (4, 4)]:
            self.failUnlessEqual(next_power(n, 2), expected)
        for (start, stop, expected) in [(5, 8, 8), (9, 16, 16), (17, 32, 32),
                                        (33, 64, 64), (65, 100, 128)]:
            for i in range(start, stop):
                self.failUnlessEqual(next_power(i, 2), expected, "%d" % i)

        # base 3, same layout
        for (n, expected) in [(0, 1), (1, 1), (2, 3), (3, 3)]:
            self.failUnlessEqual(next_power(n, 3), expected)
        for (start, stop, expected) in [(4, 9, 9), (10, 27, 27),
                                        (28, 81, 81), (82, 200, 243)]:
            for i in range(start, stop):
                self.failUnlessEqual(next_power(i, 3), expected, "%d" % i)

    def test_ave(self):
        average = mathutil.ave
        self.failUnlessEqual(average([1,2,3]), 2)
        self.failUnlessEqual(average([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(average([0.0, 1.0, 1.0]), .666666666666)

    def test_round_sigfigs(self):
        rounded = mathutil.round_sigfigs(22.0/3, 4)
        self.failUnlessEqual(rounded, 7.3330000000000002)
175
class Statistics(unittest.TestCase):
    """Tests for the probability helpers in allmydata.util.statistics."""

    def should_assert(self, msg, func, *args, **kwargs):
        """Fail the test with msg unless func(*args, **kwargs) raises an
        AssertionError.

        The self.fail() call is kept outside the try block on purpose: the
        exception trial raises to report a failure is itself an
        AssertionError subclass, so failing inside the try would be caught by
        the handler below and the check could never fail.
        """
        try:
            func(*args, **kwargs)
        except AssertionError:
            return
        self.fail(msg)

    def failUnlessListEqual(self, a, b, msg = None):
        """Assert that a and b have equal length and equal elements."""
        self.failUnlessEqual(len(a), len(b))
        for (left, right) in zip(a, b):
            self.failUnlessEqual(left, right, msg)

    def failUnlessListAlmostEqual(self, a, b, places = 7, msg = None):
        """Assert that a and b have equal length and that corresponding
        elements agree to the given number of decimal places."""
        self.failUnlessEqual(len(a), len(b))
        for (left, right) in zip(a, b):
            self.failUnlessAlmostEqual(left, right, places, msg)

    def test_binomial_coeff(self):
        f = statistics.binomial_coeff
        self.failUnlessEqual(f(20, 0), 1)
        self.failUnlessEqual(f(20, 1), 20)
        self.failUnlessEqual(f(20, 2), 190)
        # symmetry: C(n, k) == C(n, n-k)
        self.failUnlessEqual(f(20, 8), f(20, 12))
        self.should_assert("Should assert if n < k", f, 2, 3)

    def test_binomial_distribution_pmf(self):
        f = statistics.binomial_distribution_pmf

        pmf_comp = f(2, .1)
        pmf_stat = [0.81, 0.18, 0.01]
        self.failUnlessListAlmostEqual(pmf_comp, pmf_stat)

        # Summing across a PMF should give the total probability 1
        self.failUnlessAlmostEqual(sum(pmf_comp), 1)
        self.should_assert("Should assert if not 0<=p<=1", f, 1, -1)
        self.should_assert("Should assert if n < 1", f, 0, .1)

        # print_pmf is expected to write one "i=<index>: <value>" line per
        # entry to the stream it is given
        out = StringIO()
        statistics.print_pmf(pmf_comp, out=out)
        lines = out.getvalue().splitlines()
        self.failUnlessEqual(lines[0], "i=0: 0.81")
        self.failUnlessEqual(lines[1], "i=1: 0.18")
        self.failUnlessEqual(lines[2], "i=2: 0.01")

    def test_survival_pmf(self):
        f = statistics.survival_pmf
        # Cross-check binomial-distribution method against convolution
        # method.
        p_list = [.9999] * 100 + [.99] * 50 + [.8] * 20
        pmf1 = statistics.survival_pmf_via_conv(p_list)
        pmf2 = statistics.survival_pmf_via_bd(p_list)
        self.failUnlessListAlmostEqual(pmf1, pmf2)
        self.failUnlessTrue(statistics.valid_pmf(pmf1))
        self.should_assert("Should assert if p_i > 1", f, [1.1])
        self.should_assert("Should assert if p_i < 0", f, [-.1])

    def test_repair_count_pmf(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        repair_pmf = statistics.repair_count_pmf(survival_pmf, 3)
        # repair_pmf[0] == sum(survival_pmf[0,1,2,5])
        # repair_pmf[1] == survival_pmf[4]
        # repair_pmf[2] = survival_pmf[3]
        self.failUnlessListAlmostEqual(repair_pmf,
                                       [0.00001 + 0.00045 + 0.0081 + 0.59049,
                                        .32805,
                                        .0729,
                                        0, 0, 0])

    def test_repair_cost(self):
        survival_pmf = statistics.binomial_distribution_pmf(5, .9)
        bwcost = statistics.bandwidth_cost_function
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 558.90)
        cost = statistics.mean_repair_cost(bwcost, 1000,
                                           survival_pmf, 3, ul_dl_ratio=8.0)
        self.failUnlessAlmostEqual(cost, 1664.55)

        # I haven't manually checked the math beyond here -warner
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0, ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 65292.056074766246)
        cost = statistics.eternal_repair_cost(bwcost, 1000,
                                              survival_pmf, 3,
                                              discount_rate=0.05,
                                              ul_dl_ratio=1.0)
        self.failUnlessAlmostEqual(cost, 9133.6097158191551)

    def test_convolve(self):
        f = statistics.convolve
        v1 = [ 1, 2, 3 ]
        v2 = [ 4, 5, 6 ]
        v3 = [ 7, 8 ]
        v1v2result = [ 4, 13, 28, 27, 18 ]
        # Convolution is commutative
        r1 = f(v1, v2)
        r2 = f(v2, v1)
        self.failUnlessListEqual(r1, r2, "Convolution should be commutative")
        self.failUnlessListEqual(r1, v1v2result, "Didn't match known result")
        # Convolution is associative
        r1 = f(f(v1, v2), v3)
        r2 = f(v1, f(v2, v3))
        self.failUnlessListEqual(r1, r2, "Convolution should be associative")
        # Convolution is distributive
        r1 = f(v3, [ a + b for a, b in zip(v1, v2) ])
        tmp1 = f(v3, v1)
        tmp2 = f(v3, v2)
        r2 = [ a + b for a, b in zip(tmp1, tmp2) ]
        self.failUnlessListEqual(r1, r2, "Convolution should be distributive")
        # Convolution is scalar multiplication associative
        tmp1 = f(v1, v2)
        r1 = [ a * 4 for a in tmp1 ]
        tmp2 = [ a * 4 for a in v1 ]
        r2 = f(tmp2, v2)
        self.failUnlessListEqual(r1, r2, "Convolution should be scalar multiplication associative")

    def test_find_k(self):
        f = statistics.find_k
        g = statistics.pr_file_loss
        plist = [.9] * 10 + [.8] * 10 # N=20
        t = .0001
        k = f(plist, t)
        self.failUnlessEqual(k, 10)
        # the k that was found must keep the loss probability under t
        self.failUnless(g(plist, k) < t)

    def test_pr_file_loss(self):
        f = statistics.pr_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, 3), .0546875)

    def test_pr_backup_file_loss(self):
        f = statistics.pr_backup_file_loss
        plist = [.5] * 10
        self.failUnlessEqual(f(plist, .5, 3), .02734375)
312
313
314 class Asserts(unittest.TestCase):
315     def should_assert(self, func, *args, **kwargs):
316         try:
317             func(*args, **kwargs)
318         except AssertionError, e:
319             return str(e)
320         except Exception, e:
321             self.fail("assert failed with non-AssertionError: %s" % e)
322         self.fail("assert was not caught")
323
324     def should_not_assert(self, func, *args, **kwargs):
325         try:
326             func(*args, **kwargs)
327         except AssertionError, e:
328             self.fail("assertion fired when it should not have: %s" % e)
329         except Exception, e:
330             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
331         return # we're happy
332
333
334     def test_assert(self):
335         f = assertutil._assert
336         self.should_assert(f)
337         self.should_assert(f, False)
338         self.should_not_assert(f, True)
339
340         m = self.should_assert(f, False, "message")
341         self.failUnlessEqual(m, "'message' <type 'str'>", m)
342         m = self.should_assert(f, False, "message1", othermsg=12)
343         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
344         m = self.should_assert(f, False, othermsg="message2")
345         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
346
347     def test_precondition(self):
348         f = assertutil.precondition
349         self.should_assert(f)
350         self.should_assert(f, False)
351         self.should_not_assert(f, True)
352
353         m = self.should_assert(f, False, "message")
354         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
355         m = self.should_assert(f, False, "message1", othermsg=12)
356         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
357         m = self.should_assert(f, False, othermsg="message2")
358         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
359
360     def test_postcondition(self):
361         f = assertutil.postcondition
362         self.should_assert(f)
363         self.should_assert(f, False)
364         self.should_not_assert(f, True)
365
366         m = self.should_assert(f, False, "message")
367         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
368         m = self.should_assert(f, False, "message1", othermsg=12)
369         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
370         m = self.should_assert(f, False, othermsg="message2")
371         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
372
373 class FileUtil(unittest.TestCase):
374     def mkdir(self, basedir, path, mode=0777):
375         fn = os.path.join(basedir, path)
376         fileutil.make_dirs(fn, mode)
377
378     def touch(self, basedir, path, mode=None, data="touch\n"):
379         fn = os.path.join(basedir, path)
380         f = open(fn, "w")
381         f.write(data)
382         f.close()
383         if mode is not None:
384             os.chmod(fn, mode)
385
386     def test_rm_dir(self):
387         basedir = "util/FileUtil/test_rm_dir"
388         fileutil.make_dirs(basedir)
389         # create it again to test idempotency
390         fileutil.make_dirs(basedir)
391         d = os.path.join(basedir, "doomed")
392         self.mkdir(d, "a/b")
393         self.touch(d, "a/b/1.txt")
394         self.touch(d, "a/b/2.txt", 0444)
395         self.touch(d, "a/b/3.txt", 0)
396         self.mkdir(d, "a/c")
397         self.touch(d, "a/c/1.txt")
398         self.touch(d, "a/c/2.txt", 0444)
399         self.touch(d, "a/c/3.txt", 0)
400         os.chmod(os.path.join(d, "a/c"), 0444)
401         self.mkdir(d, "a/d")
402         self.touch(d, "a/d/1.txt")
403         self.touch(d, "a/d/2.txt", 0444)
404         self.touch(d, "a/d/3.txt", 0)
405         os.chmod(os.path.join(d, "a/d"), 0)
406
407         fileutil.rm_dir(d)
408         self.failIf(os.path.exists(d))
409         # remove it again to test idempotency
410         fileutil.rm_dir(d)
411
412     def test_remove_if_possible(self):
413         basedir = "util/FileUtil/test_remove_if_possible"
414         fileutil.make_dirs(basedir)
415         self.touch(basedir, "here")
416         fn = os.path.join(basedir, "here")
417         fileutil.remove_if_possible(fn)
418         self.failIf(os.path.exists(fn))
419         fileutil.remove_if_possible(fn) # should be idempotent
420         fileutil.rm_dir(basedir)
421         fileutil.remove_if_possible(fn) # should survive errors
422
423     def test_open_or_create(self):
424         basedir = "util/FileUtil/test_open_or_create"
425         fileutil.make_dirs(basedir)
426         fn = os.path.join(basedir, "here")
427         f = fileutil.open_or_create(fn)
428         f.write("stuff.")
429         f.close()
430         f = fileutil.open_or_create(fn)
431         f.seek(0, 2)
432         f.write("more.")
433         f.close()
434         f = open(fn, "r")
435         data = f.read()
436         f.close()
437         self.failUnlessEqual(data, "stuff.more.")
438
439     def test_NamedTemporaryDirectory(self):
440         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
441         fileutil.make_dirs(basedir)
442         td = fileutil.NamedTemporaryDirectory(dir=basedir)
443         name = td.name
444         self.failUnless(basedir in name)
445         self.failUnless(basedir in repr(td))
446         self.failUnless(os.path.isdir(name))
447         del td
448         # it is conceivable that we need to force gc here, but I'm not sure
449         self.failIf(os.path.isdir(name))
450
451     def test_rename(self):
452         basedir = "util/FileUtil/test_rename"
453         fileutil.make_dirs(basedir)
454         self.touch(basedir, "here")
455         fn = os.path.join(basedir, "here")
456         fn2 = os.path.join(basedir, "there")
457         fileutil.rename(fn, fn2)
458         self.failIf(os.path.exists(fn))
459         self.failUnless(os.path.exists(fn2))
460
461     def test_du(self):
462         basedir = "util/FileUtil/test_du"
463         fileutil.make_dirs(basedir)
464         d = os.path.join(basedir, "space-consuming")
465         self.mkdir(d, "a/b")
466         self.touch(d, "a/b/1.txt", data="a"*10)
467         self.touch(d, "a/b/2.txt", data="b"*11)
468         self.mkdir(d, "a/c")
469         self.touch(d, "a/c/1.txt", data="c"*12)
470         self.touch(d, "a/c/2.txt", data="d"*13)
471
472         used = fileutil.du(basedir)
473         self.failUnlessEqual(10+11+12+13, used)
474
475     def test_abspath_expanduser_unicode(self):
476         self.failUnlessRaises(AssertionError, fileutil.abspath_expanduser_unicode, "bytestring")
477
478         saved_cwd = os.path.normpath(os.getcwdu())
479         abspath_cwd = fileutil.abspath_expanduser_unicode(u".")
480         self.failUnless(isinstance(saved_cwd, unicode), saved_cwd)
481         self.failUnless(isinstance(abspath_cwd, unicode), abspath_cwd)
482         self.failUnlessEqual(abspath_cwd, saved_cwd)
483
484         # adapted from <http://svn.python.org/view/python/branches/release26-maint/Lib/test/test_posixpath.py?view=markup&pathrev=78279#test_abspath>
485
486         self.failUnlessIn(u"foo", fileutil.abspath_expanduser_unicode(u"foo"))
487         self.failIfIn(u"~", fileutil.abspath_expanduser_unicode(u"~"))
488
489         cwds = ['cwd']
490         try:
491             cwds.append(u'\xe7w\xf0'.encode(sys.getfilesystemencoding()
492                                             or 'ascii'))
493         except UnicodeEncodeError:
494             pass # the cwd can't be encoded -- test with ascii cwd only
495
496         for cwd in cwds:
497             try:
498                 os.mkdir(cwd)
499                 os.chdir(cwd)
500                 for upath in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\', u'~'):
501                     uabspath = fileutil.abspath_expanduser_unicode(upath)
502                     self.failUnless(isinstance(uabspath, unicode), uabspath)
503             finally:
504                 os.chdir(saved_cwd)
505
506     def test_disk_stats(self):
507         avail = fileutil.get_available_space('.', 2**14)
508         if avail == 0:
509             raise unittest.SkipTest("This test will spuriously fail there is no disk space left.")
510
511         disk = fileutil.get_disk_stats('.', 2**13)
512         self.failUnless(disk['total'] > 0, disk['total'])
513         self.failUnless(disk['used'] > 0, disk['used'])
514         self.failUnless(disk['free_for_root'] > 0, disk['free_for_root'])
515         self.failUnless(disk['free_for_nonroot'] > 0, disk['free_for_nonroot'])
516         self.failUnless(disk['avail'] > 0, disk['avail'])
517
518     def test_disk_stats_avail_nonnegative(self):
519         # This test will spuriously fail if you have more than 2^128
520         # bytes of available space on your filesystem.
521         disk = fileutil.get_disk_stats('.', 2**128)
522         self.failUnlessEqual(disk['avail'], 0)
523
class PollMixinTests(unittest.TestCase):
    """Tests for pollmixin.PollMixin.poll."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # the check passes immediately
        return self.pm.poll(check_f=lambda : True, pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # the check fails on the first call and passes on the second
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next, pollinterval=0.1)

    def test_timeout(self):
        def _unexpected_success(res):
            self.fail("poll should have failed, not returned %s" % (res,))
        def _expected_failure(f):
            f.trap(pollmixin.TimeoutError)
            return None # success

        # the check never passes, so the poll should end in a TimeoutError
        d = self.pm.poll(check_f=lambda: False,
                         pollinterval=0.01,
                         timeout=1)
        d.addCallbacks(_unexpected_success, _expected_failure)
        return d
550
class DeferredUtilTests(unittest.TestCase):
    """Tests for gatherResults and DeferredListShouldSucceed in deferredutil."""

    def test_gather_results(self):
        def _unexpected_result(res):
            self.fail("Should have errbacked, not resulted in %s" % (res,))
        def _expected_error(thef):
            thef.trap(ValueError)

        failing = defer.Deferred()
        pending = defer.Deferred()
        gathered = deferredutil.gatherResults([failing, pending])
        # one failing input is enough to errback the gathered Deferred
        failing.errback(ValueError("BAD"))
        gathered.addCallbacks(_unexpected_result, _expected_error)
        return gathered

    def test_success(self):
        first = defer.Deferred()
        second = defer.Deferred()
        successes = []
        failures = []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        first.callback(1)
        second.callback(2)
        # both results arrive together, in input order, and nothing fails
        self.failUnlessEqual(successes, [[1,2]])
        self.failUnlessEqual(failures, [])

    def test_failure(self):
        first = defer.Deferred()
        second = defer.Deferred()
        successes = []
        failures = []
        dlss = deferredutil.DeferredListShouldSucceed([first, second])
        dlss.addCallbacks(successes.append, failures.append)
        # swallow any failure on the inputs themselves
        for each in (first, second):
            each.addErrback(lambda _ignore: None)
        first.callback(1)
        second.errback(ValueError())
        self.failUnlessEqual(successes, [])
        self.failUnlessEqual(len(failures), 1)
        reported = failures[0]
        self.failUnless(isinstance(reported, Failure))
        self.failUnless(reported.check(ValueError))
590
class HashUtilTests(unittest.TestCase):
    """Tests for the tagged hashes and related helpers in hashutil."""

    def test_random_key(self):
        key = hashutil.random_key()
        self.failUnlessEqual(len(key), hashutil.KEYLEN)

    def test_sha256d(self):
        one_shot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        # digest() is called twice and must give the same answer both times
        first_digest = hasher.digest()
        second_digest = hasher.digest()
        self.failUnlessEqual(one_shot, first_digest)
        self.failUnlessEqual(first_digest, second_digest)

    def test_sha256d_truncated(self):
        one_shot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(one_shot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(one_shot, incremental)

    def test_chk(self):
        one_shot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        self.failUnlessEqual(one_shot, hasher.digest())

    def test_hashers(self):
        # each one-shot hash function must agree with its incremental hasher
        pairs = [
            (hashutil.block_hash, hashutil.block_hasher),
            (hashutil.uri_extension_hash, hashutil.uri_extension_hasher),
            (hashutil.plaintext_hash, hashutil.plaintext_hasher),
            (hashutil.crypttext_hash, hashutil.crypttext_hasher),
            (hashutil.crypttext_segment_hash, hashutil.crypttext_segment_hasher),
            (hashutil.plaintext_segment_hash, hashutil.plaintext_segment_hasher),
            ]
        for (hash_func, make_hasher) in pairs:
            one_shot = hash_func("foo")
            hasher = make_hasher()
            hasher.update("foo")
            self.failUnlessEqual(one_shot, hasher.digest())

    def test_constant_time_compare(self):
        compare = hashutil.constant_time_compare
        for (left, right) in [("a", "a"), ("ab", "ab")]:
            self.failUnless(compare(left, right))
        for (left, right) in [("a", "b"), ("a", "aa")]:
            self.failIf(compare(left, right))

    def _testknown(self, hashf, expected_a, *args):
        """Assert that the base32 form of hashf(*args) equals expected_a."""
        got_a = base32.b2a(hashf(*args))
        self.failUnlessEqual(got_a, expected_a)

    def test_known_answers(self):
        # assert backwards compatibility
        # each entry is (hash function, expected base32 output, arguments)
        zeros = "\x00"*20
        known = [
            (hashutil.storage_index_hash, "qb5igbhcc5esa6lwqorsy7e6am", ("",)),
            (hashutil.block_hash, "msjr5bh4evuh7fa3zw7uovixfbvlnstr5b65mrerwfnvjxig2jvq", ("",)),
            (hashutil.uri_extension_hash, "wthsu45q7zewac2mnivoaa4ulh5xvbzdmsbuyztq2a5fzxdrnkka", ("",)),
            (hashutil.plaintext_hash, "5lz5hwz3qj3af7n6e3arblw7xzutvnd3p3fjsngqjcb7utf3x3da", ("",)),
            (hashutil.crypttext_hash, "itdj6e4njtkoiavlrmxkvpreosscssklunhwtvxn6ggho4rkqwga", ("",)),
            (hashutil.crypttext_segment_hash, "aovy5aa7jej6ym5ikgwyoi4pxawnoj3wtaludjz7e2nb5xijb7aa", ("",)),
            (hashutil.plaintext_segment_hash, "4fdgf6qruaisyukhqcmoth4t3li6bkolbxvjy4awwcpprdtva7za", ("",)),
            (hashutil.convergence_hash, "3mo6ni7xweplycin6nowynw2we", (3, 10, 100, "", "converge")),
            (hashutil.my_renewal_secret_hash, "ujhr5k5f7ypkp67jkpx6jl4p47pyta7hu5m527cpcgvkafsefm6q", ("",)),
            (hashutil.my_cancel_secret_hash, "rjwzmafe2duixvqy6h47f5wfrokdziry6zhx4smew4cj6iocsfaa", ("",)),
            (hashutil.file_renewal_secret_hash, "hzshk2kf33gzbd5n3a6eszkf6q6o6kixmnag25pniusyaulqjnia", ("", "si")),
            (hashutil.file_cancel_secret_hash, "bfciwvr6w7wcavsngxzxsxxaszj72dej54n4tu2idzp6b74g255q", ("", "si")),
            (hashutil.bucket_renewal_secret_hash, "e7imrzgzaoashsncacvy3oysdd2m5yvtooo4gmj4mjlopsazmvuq", ("", zeros)),
            (hashutil.bucket_cancel_secret_hash, "dvdujeyxeirj6uux6g7xcf4lvesk632aulwkzjar7srildvtqwma", ("", zeros)),
            (hashutil.hmac, "c54ypfi6pevb3nvo6ba42jtglpkry2kbdopqsi7dgrm4r7tw5sra", ("tag", "")),
            (hashutil.mutable_rwcap_key_hash, "6rvn2iqrghii5n4jbbwwqqsnqu", ("iv", "wk")),
            (hashutil.ssk_writekey_hash, "ykpgmdbpgbb6yqz5oluw2q26ye", ("",)),
            (hashutil.ssk_write_enabler_master_hash, "izbfbfkoait4dummruol3gy2bnixrrrslgye6ycmkuyujnenzpia", ("",)),
            (hashutil.ssk_write_enabler_hash, "fuu2dvx7g6gqu5x22vfhtyed7p4pd47y5hgxbqzgrlyvxoev62tq", ("wk", zeros)),
            (hashutil.ssk_pubkey_fingerprint_hash, "3opzw4hhm2sgncjx224qmt5ipqgagn7h5zivnfzqycvgqgmgz35q", ("",)),
            (hashutil.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", ("",)),
            (hashutil.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", ("iv", "rk")),
            (hashutil.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", ("",)),
            ]
        for (hashf, expected_a, args) in known:
            self._testknown(hashf, expected_a, *args)
689
690
class Abbreviate(unittest.TestCase):
    """Known-answer tests for the allmydata.util.abbreviate helpers."""

    def test_time(self):
        """abbreviate_time() turns a number of seconds into a short phrase."""
        MIN = 60
        HOUR = 60*MIN
        DAY = 24*HOUR
        MONTH = 30*DAY
        YEAR = 365*DAY
        known_answers = [
            (None, "unknown"),
            (0, "0 seconds"),
            (1, "1 second"),
            (2, "2 seconds"),
            (119, "119 seconds"),
            (2*MIN, "2 minutes"),
            (60*MIN, "60 minutes"),
            (179*MIN, "179 minutes"),
            (180*MIN, "3 hours"),
            (4*HOUR, "4 hours"),
            (2*DAY, "2 days"),
            (2*MONTH, "2 months"),
            (5*YEAR, "5 years"),
            ]
        for (seconds, expected) in known_answers:
            self.failUnlessEqual(abbreviate.abbreviate_time(seconds),
                                 expected)

    def test_space(self):
        """abbreviate_space() in SI and base-1024 modes, plus the combined form."""
        K = 1000
        Ki = 1024
        # sizes below 1024 are rendered the same way in both modes
        plain_bytes = [
            (None, "unknown"),
            (0, "0 B"),
            (1, "1 B"),
            (999, "999 B"),
            (1000, "1000 B"),
            (1023, "1023 B"),
            ]
        si_cases = plain_bytes + [
            (Ki, "1.02 kB"),
            (20*K, "20.00 kB"),
            (Ki**2, "1.05 MB"),
            (K**2, "1.00 MB"),
            (K**3, "1.00 GB"),
            (K**4, "1.00 TB"),
            (K**5, "1.00 PB"),
            (1234567890123456, "1.23 PB"),
            ]
        base1024_cases = plain_bytes + [
            (Ki, "1.00 kiB"),
            (20*Ki, "20.00 kiB"),
            (K**2, "976.56 kiB"),
            (Ki**2, "1.00 MiB"),
            (Ki**3, "1.00 GiB"),
            (Ki**4, "1.00 TiB"),
            (K**5, "909.49 TiB"),
            (Ki**5, "1.00 PiB"),
            (1234567890123456, "1.10 PiB"),
            ]
        for (use_si, cases) in [(True, si_cases), (False, base1024_cases)]:
            for (size, expected) in cases:
                rendered = abbreviate.abbreviate_space(size, SI=use_si)
                self.failUnlessEqual(rendered, expected)

        both = abbreviate.abbreviate_space_both(1234567)
        self.failUnlessEqual(both, "(1.23 MB, 1.18 MiB)")

    def test_parse_space(self):
        """parse_abbreviated_size() accepts SI and binary suffixes."""
        parse = abbreviate.parse_abbreviated_size
        known_answers = [
            ("", None),
            (None, None),
            ("123", 123),
            ("123B", 123),
            ("2K", 2000),
            ("2kb", 2000),
            ("2KiB", 2048),
            ("10MB", 10*1000*1000),
            ("10MiB", 10*1024*1024),
            ("5G", 5*1000*1000*1000),
            ("4GiB", 4*1024*1024*1024),
            ]
        for (text, expected) in known_answers:
            self.failUnlessEqual(parse(text), expected)

        # an unrecognized unit is rejected, and the error quotes the input
        err = self.failUnlessRaises(ValueError, parse, "12 cubits")
        self.failUnless("12 cubits" in str(err))
771
class Limiter(unittest.TestCase):
    """Tests for limiter.ConcurrencyLimiter using jobs that take one second."""

    timeout = 480 # This takes longer than 240 seconds on Francois's arm box.

    def job(self, i, foo):
        """Record the call, count it as running, and finish one second later.

        Returns a Deferred that fires with "done <i>" once the job retires.
        """
        self.calls.append((i, foo))
        self.simultaneous = self.simultaneous + 1
        if self.simultaneous > self.peak_simultaneous:
            self.peak_simultaneous = self.simultaneous
        finished = defer.Deferred()
        def _retire():
            self.simultaneous = self.simultaneous - 1
            finished.callback("done %d" % i)
        reactor.callLater(1.0, _retire)
        return finished

    def bad_job(self, i, foo):
        """A job that fails synchronously with ValueError."""
        raise ValueError("bad_job %d" % i)

    def _reset_counters(self):
        # fresh bookkeeping for job()
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0

    def _assert_every_job_ran(self):
        # never more than 10 jobs at once, and each of the 20 jobs was called
        self.failUnless(self.peak_simultaneous <= 10)
        self.failUnlessEqual(len(self.calls), 20)
        for i in range(20):
            self.failUnless((i, str(i)) in self.calls)

    def _assert_drained(self, ignored):
        # nothing is still running, and all jobs were seen
        self.failUnlessEqual(self.simultaneous, 0)
        self._assert_every_job_ran()

    def test_limiter(self):
        """Twenty jobs all run, with at most ten in flight at a time."""
        self._reset_counters()
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        d = defer.DeferredList(pending, fireOnOneErrback=True)
        d.addCallback(self._assert_drained)
        return d

    def test_errors(self):
        """A failing job errbacks its own Deferred without disturbing the rest."""
        self._reset_counters()
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        bad = cl.add(self.bad_job, 21, "21")
        d = defer.DeferredList(pending, fireOnOneErrback=True)

        def _unexpected_success(res):
            self.fail("should have failed, not got %s" % (res,))
        def _expected_failure(f):
            f.trap(ValueError)
            self.failUnless("bad_job 21" in str(f))

        def _good_jobs_done(res):
            outcomes = []
            for (success, result) in res:
                self.failUnlessEqual(success, True)
                outcomes.append(result)
            expected = sorted(["done %d" % i for i in range(20)])
            self.failUnlessEqual(sorted(outcomes), expected)
            self._assert_every_job_ran()
            # only now look at the bad job's Deferred
            bad.addCallbacks(_unexpected_success, _expected_failure)
            return bad
        d.addCallback(_good_jobs_done)
        d.addCallback(self._assert_drained)
        return d
846
class TimeFormat(unittest.TestCase):
    """Tests for the allmydata.util.time_format conversion helpers."""

    def test_epoch(self):
        """Run the epoch checks in whatever timezone the process is using."""
        return self._help_test_epoch()

    def test_epoch_in_London(self):
        """Run the epoch checks with TZ temporarily set to Europe/London."""
        # Europe/London is an awkward zone: today its offset from GMT is 0,
        # but in 1970 it was 1 (apparently Britain had redefined standard
        # time as GMT+1 and kept it all year, whereas now standard time is
        # GMT and Daylight Savings Time is GMT+1).  The current
        # implementation of time_format.iso_utc_time_to_localseconds()
        # breaks in this zone.  (Once this unit test is done, the plan is to
        # change that implementation to something that works here too...)
        can_tzset = hasattr(time, 'tzset')
        saved_tz = os.environ.get('TZ')
        os.environ['TZ'] = "Europe/London"
        if can_tzset:
            time.tzset()
        try:
            return self._help_test_epoch()
        finally:
            # put TZ back exactly as we found it (including "unset")
            if saved_tz is not None:
                os.environ['TZ'] = saved_tz
            else:
                del os.environ['TZ']
            if can_tzset:
                time.tzset()

    def _help_test_epoch(self):
        """Shared body: parse and format timestamps around the epoch."""
        parse = time_format.iso_utc_time_to_seconds
        tzname_before = time.tzname

        # 'T', '_' and ' ' are all accepted between the date and the time
        for stamp in ("1970-01-01T00:00:01",
                      "1970-01-01_00:00:01",
                      "1970-01-01 00:00:01"):
            self.failUnlessEqual(parse(stamp), 1.0)

        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")

        # formatting and then parsing the current time agrees to the second
        now = time.time()
        roundtripped = parse(time_format.iso_utc(now))
        self.failUnlessEqual(int(roundtripped), int(now))

        # the clock can be injected through the t= argument
        self.failUnlessEqual(time_format.iso_utc(t=lambda: 1.0),
                             "1970-01-01_00:00:01")

        err = self.failUnlessRaises(ValueError, parse, "invalid timestring")
        self.failUnless("not a complete ISO8601 timestamp" in str(err))

        self.failUnlessEqual(parse("1970-01-01_00:00:01.500"), 1.5)

        # Look for daylight-savings-related errors.
        moment_in_march = parse("2009-03-20 21:49:02.226536")
        self.failUnlessEqual(moment_in_march, 1237585742.226536)
        self.failUnlessEqual(tzname_before, time.tzname)

    def test_iso_utc(self):
        """iso_utc_date() and iso_utc() format a fixed timestamp."""
        when = 1266760143.7841301
        self.failUnlessEqual(time_format.iso_utc_date(when), "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc_date(t=lambda: when),
                             "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc(when),
                             "2010-02-21_13:49:03.784130")
        self.failUnlessEqual(time_format.iso_utc(when, sep="-"),
                             "2010-02-21-13:49:03.784130")

    def test_parse_duration(self):
        """parse_duration() converts '<n> <unit>' strings to seconds."""
        parse = time_format.parse_duration
        DAY = 24*60*60
        known_answers = [
            ("1 day", DAY),
            ("2 days", 2*DAY),
            ("3 months", 3*31*DAY),
            ("4 mo", 4*31*DAY),
            ("5 years", 5*365*DAY),
            ]
        for (text, expected) in known_answers:
            self.failUnlessEqual(parse(text), expected)

        # a bare number has no unit and is rejected
        err = self.failUnlessRaises(ValueError, parse, "123")
        self.failUnlessIn("no unit (like day, month, or year) in '123'",
                          str(err))

    def test_parse_date(self):
        """parse_date() maps an ISO date to the expected seconds value."""
        parsed = time_format.parse_date("2010-02-21")
        self.failUnlessEqual(parsed, 1266710400)
934
class CacheDir(unittest.TestCase):
    """Tests for cachedir.CacheDirectoryManager."""

    def test_basic(self):
        """Check which cache files survive cdm.check() as references are
        dropped and the manager's 'old' threshold is changed."""
        basedir = "test_util/CacheDir/test_basic"

        # (name, should_exist) expectations, always checked in a, b, c order
        all_present = [("a", True), ("b", True), ("c", True)]
        a_removed = [("a", False), ("b", True), ("c", True)]

        def _verify(expectations):
            for (name, should_exist) in expectations:
                absfn = os.path.join(basedir, name)
                found = os.path.exists(absfn)
                if should_exist:
                    self.failUnless(found,
                                    "%s doesn't exist but it should" % absfn)
                else:
                    self.failIf(found,
                                "%s exists but it shouldn't" % absfn)

        def _write_hi(filename):
            # takes a filename, not the cache-file object, so this helper
            # never holds an extra reference to that object
            f = open(filename, "wb")
            f.write("hi")
            f.close()

        cdm = cachedir.CacheDirectoryManager(basedir)
        file_a = cdm.get_file("a")
        file_b = cdm.get_file("b")
        file_c = cdm.get_file("c")
        _write_hi(file_a.get_filename())
        _write_hi(file_b.get_filename())
        _write_hi(file_c.get_filename())
        _verify(all_present)

        cdm.check()
        _verify(all_present)

        # drop our reference to "a"; the file isn't old enough yet, so
        # check() leaves it in place
        del file_a
        cdm.check()
        _verify(all_present)

        # redefine "old" so that everything counts as old
        cdm.old = -10
        cdm.check()
        _verify(a_removed)

        # back to a one-hour threshold before letting go of "b"
        cdm.old = 60*60
        del file_b
        cdm.check()
        _verify(a_removed)

        # asking for "b" again, then checking, still leaves "b" on disk
        b_again = cdm.get_file("b")
        cdm.check()
        _verify(a_removed)
        del b_again
998
# Module-level counter; each EqButNotIs instance takes the next value as its
# hash.
ctr = [0]
class EqButNotIs:
    """Wrapper whose comparisons are answered by its payload `x`, while its
    hash comes from the `ctr` counter, so two wrappers around equal values
    compare equal but hash differently."""
    def __init__(self, x):
        self.x = x
        # claim the current counter value, then advance the counter
        self.hash = ctr[0]
        ctr[0] = ctr[0] + 1
    def __hash__(self):
        return self.hash
    def __repr__(self):
        name = self.__class__.__name__
        return "<%s %s>" % (name, self.x)
    # Every rich comparison is delegated to the wrapped value.
    def __eq__(self, other):
        return self.x == other
    def __ne__(self, other):
        return self.x != other
    def __lt__(self, other):
        return self.x < other
    def __le__(self, other):
        return self.x <= other
    def __gt__(self, other):
        return self.x > other
    def __ge__(self, other):
        return self.x >= other
1021
class DictUtil(unittest.TestCase):
    """Tests for the dictionary classes and helpers in allmydata.util.dictutil."""

    def _help_test_empty_dict(self, klass):
        """Check that klass() and klass({}) compare equal and have length 0."""
        d1 = klass()
        d2 = klass({})

        self.failUnless(d1 == d2, "d1: %r, d2: %r" % (d1, d2,))
        self.failUnless(len(d1) == 0)
        self.failUnless(len(d2) == 0)

    def _help_test_nonempty_dict(self, klass):
        """Check that two klass instances built from the same three items
        compare equal and both report length 3."""
        d1 = klass({'a': 1, 'b': "eggs", 3: "spam",})
        d2 = klass({'a': 1, 'b': "eggs", 3: "spam",})

        self.failUnless(d1 == d2)
        self.failUnless(len(d1) == 3, "%s, %s" % (len(d1), d1,))
        self.failUnless(len(d2) == 3)

    def _help_test_eq_but_notis(self, klass):
        """Exercise klass with EqButNotIs objects, which compare equal to
        plain values but are distinct objects with distinct hashes."""
        d = klass({'a': 3, 'b': EqButNotIs(3), 'c': 3})
        d.pop('b')

        d.clear()
        d['a'] = 3
        d['b'] = EqButNotIs(3)
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['b'] = EqButNotIs(3)
        d['a'] = 3
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['a'] = EqButNotIs(3)
        d['c'] = 3
        d['a'] = 3

        # NOTE(review): the `is` checks against int literals below look
        # deliberate (this helper is about identity versus equality) and
        # presumably rely on the interpreter reusing small-int objects --
        # confirm before porting to another interpreter.
        d.clear()
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

        # same again, but with the real key inserted before the fake one
        d.clear()
        d[3] = 7
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

    def test_all(self):
        """Run every _help_test_* helper against each dict class.

        Each helper is invoked exactly once per class, so the empty-dict
        helper is exercised too and no helper is run twice.
        """
        for klass in (dictutil.UtilDict,
                      dictutil.NumDict,
                      dictutil.ValueOrderedDict):
            self._help_test_empty_dict(klass)
        for klass in (dictutil.UtilDict,
                      dictutil.NumDict,
                      dictutil.ValueOrderedDict):
            self._help_test_nonempty_dict(klass)
        for klass in (dictutil.UtilDict,
                      dictutil.NumDict,
                      dictutil.ValueOrderedDict):
            self._help_test_eq_but_notis(klass)

    def test_dict_of_sets(self):
        """DictOfSets: add, discard, union and update."""
        ds = dictutil.DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        # discarding the last member removes the key itself
        ds.discard(2, "c")
        self.failIf(2 in ds)

        ds.union(1, ["a", "e"])
        ds.union(3, ["f"])
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f"]))
        ds2 = dictutil.DictOfSets()
        ds2.add(3, "f")
        ds2.add(3, "g")
        ds2.add(4, "h")
        ds.update(ds2)
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f", "g"]))
        self.failUnlessEqual(ds[4], set(["h"]))

    def test_move(self):
        """dictutil.move(k, d1, d2) takes k out of d1 and stores it in d2."""
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(1, d1, d2)
        self.failUnlessEqual(d1, {2: "b"})
        self.failUnlessEqual(d2, {1: "a", 2: "c", 3: "d"})

        # a key already present in the destination is overwritten
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(2, d1, d2)
        self.failUnlessEqual(d1, {1: "a"})
        self.failUnlessEqual(d2, {2: "b", 3: "d"})

        # with strict=True a key missing from the source raises KeyError
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        self.failUnlessRaises(KeyError, dictutil.move, 5, d1, d2, strict=True)

    def test_subtract(self):
        """dictutil.subtract(d1, d2) keeps the d1 items whose keys are not in d2."""
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

    def test_utildict(self):
        """UtilDict: deletion, sorting helpers, comparisons and the dict API."""
        d = dictutil.UtilDict({1: "a", 2: "b"})
        d.del_if_present(1)
        d.del_if_present(3)
        self.failUnlessEqual(d, {2: "b"})
        def eq(a, b):
            return a == b
        # comparing against something that is not a dict raises TypeError
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d = dictutil.UtilDict({1: "b", 2: "a"})
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [(2, "a"), (1, "b")])
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(repr(d), "{1: 'b', 2: 'a'}")
        self.failUnless(1 in d)

        d2 = dictutil.UtilDict({3: "c", 4: "d"})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d[1], "b")
        self.failUnlessEqual(sorted(list([k for k in d])), [1,2])

        # copy() preserves the class
        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.UtilDict))

        d4 = d.fromkeys([3,4], "e")
        self.failUnlessEqual(d4, {3: "e", 4: "e"})

        self.failUnlessEqual(d.get(1), "b")
        self.failUnlessEqual(d.get(3), None)
        self.failUnlessEqual(d.get(3, "default"), "default")
        self.failUnlessEqual(sorted(list(d.items())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.keys()), [1, 2])
        self.failUnlessEqual(sorted(d.values()), ["a", "b"])
        # setdefault() leaves an existing value alone and stores a new one
        x = d.setdefault(1, "new")
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d[1], "b")
        x = d.setdefault(3, "new")
        self.failUnlessEqual(x, "new")
        self.failUnlessEqual(d[3], "new")
        del d[3]

        # popitem() drains the two remaining items, then raises KeyError
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        self.failUnlessRaises(KeyError, d.popitem)

    def test_numdict(self):
        """NumDict: arithmetic helpers, sorting helpers and the dict API."""
        d = dictutil.NumDict({"a": 1, "b": 2})

        # the optional third argument is the starting value for a missing key
        d.add_num("a", 10, 5)
        d.add_num("c", 20, 5)
        d.add_num("d", 30)
        self.failUnlessEqual(d, {"a": 11, "b": 2, "c": 25, "d": 30})

        d.subtract_num("a", 10)
        d.subtract_num("e", 10)
        d.subtract_num("f", 10, 15)
        self.failUnlessEqual(d, {"a": 1, "b": 2, "c": 25, "d": 30,
                                 "e": -10, "f": 5})

        self.failUnlessEqual(d.sum(), sum([1, 2, 25, 30, -10, 5]))

        d = dictutil.NumDict()
        d.inc("a")
        d.inc("a")
        d.inc("b", 5)
        self.failUnlessEqual(d, {"a": 2, "b": 6})
        d.dec("a")
        d.dec("c")
        d.dec("d", 5)
        self.failUnlessEqual(d, {"a": 1, "b": 6, "c": -1, "d": 4})
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [("a", 1), ("b", 6), ("c", -1), ("d", 4)])
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [("c", -1), ("a", 1), ("d", 4), ("b", 6)])
        self.failUnlessEqual(d.item_with_largest_value(), ("b", 6))

        d = dictutil.NumDict({"a": 1, "b": 2})
        self.failUnlessEqual(repr(d), "{'a': 1, 'b': 2}")
        self.failUnless("a" in d)

        d2 = dictutil.NumDict({"c": 3, "d": 4})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d["a"], 1)
        self.failUnlessEqual(sorted(list([k for k in d])), ["a","b"])
        def eq(a, b):
            return a == b
        # comparing against something that is not a dict raises TypeError
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        # copy() preserves the class
        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.NumDict))

        d4 = d.fromkeys(["a","b"], 5)
        self.failUnlessEqual(d4, {"a": 5, "b": 5})

        # get() on a missing key yields 0 unless a default is supplied
        self.failUnlessEqual(d.get("a"), 1)
        self.failUnlessEqual(d.get("c"), 0)
        self.failUnlessEqual(d.get("c", 5), 5)
        self.failUnlessEqual(sorted(list(d.items())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.keys()), ["a", "b"])
        self.failUnlessEqual(sorted(d.values()), [1, 2])
        self.failUnless(d.has_key("a"))
        self.failIf(d.has_key("c"))

        x = d.setdefault("c", 3)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        x = d.setdefault("c", 5)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        del d["c"]

        # popitem() drains the two remaining items, then raises KeyError
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        self.failUnlessRaises(KeyError, d.popitem)

        d.update({"c": 3})
        d.update({"c": 4, "d": 5})
        self.failUnlessEqual(d, {"c": 4, "d": 5})

    def test_del_if_present(self):
        """dictutil.del_if_present() removes a key and ignores a missing one."""
        d = {1: "a", 2: "b"}
        dictutil.del_if_present(d, 1)
        dictutil.del_if_present(d, 3)
        self.failUnlessEqual(d, {2: "b"})

    def test_valueordereddict(self):
        """ValueOrderedDict: items, keys and values come back ordered by value."""
        d = dictutil.ValueOrderedDict()
        d["a"] = 3
        d["b"] = 2
        d["c"] = 1

        self.failUnlessEqual(d, {"a": 3, "b": 2, "c": 1})
        self.failUnlessEqual(d.items(), [("c", 1), ("b", 2), ("a", 3)])
        self.failUnlessEqual(d.values(), [1, 2, 3])
        self.failUnlessEqual(d.keys(), ["c", "b", "a"])
        self.failUnlessEqual(repr(d), "<ValueOrderedDict {c: 1, b: 2, a: 3}>")
        self.failIf(d == {"a": 4})
        self.failUnless(d != {"a": 4})

        x = d.setdefault("d", 0)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)
        x = d.setdefault("d", -1)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)

        # remove(key, default, strictkey): the default is returned for a
        # missing key unless the third argument is True
        x = d.remove("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.remove, "e", "default", True)
        x = d.remove("d", 5)
        self.failUnlessEqual(x, 0)

        # __getitem__ accepts the same extra (default, strict) arguments
        x = d.__getitem__("c")
        self.failUnlessEqual(x, 1)
        x = d.__getitem__("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.__getitem__, "e", "default", True)

        # popitem() hands items back smallest value first
        self.failUnlessEqual(d.popitem(), ("c", 1))
        self.failUnlessEqual(d.popitem(), ("b", 2))
        self.failUnlessEqual(d.popitem(), ("a", 3))
        self.failUnlessRaises(KeyError, d.popitem)

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop("d", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.pop, "d", "default", True)
        x = d.pop("b")
        self.failUnlessEqual(x, 2)
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop_from_list(1) # pop the second item, b/2
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

    def test_auxdict(self):
        """AuxValueDict: each key may carry an auxiliary value next to its
        main value."""
        d = dictutil.AuxValueDict()
        # we put the serialized form in the auxdata
        d.set_with_aux("key", ("filecap", "metadata"), "serialized")

        self.failUnlessEqual(d.keys(), ["key"])
        self.failUnlessEqual(d["key"], ("filecap", "metadata"))
        self.failUnlessEqual(d.get_aux("key"), "serialized")
        def _get_missing(key):
            return d[key]
        self.failUnlessRaises(KeyError, _get_missing, "nonkey")
        self.failUnlessEqual(d.get("nonkey"), None)
        self.failUnlessEqual(d.get("nonkey", "nonvalue"), "nonvalue")
        self.failUnlessEqual(d.get_aux("nonkey"), None)
        self.failUnlessEqual(d.get_aux("nonkey", "nonvalue"), "nonvalue")

        # a plain assignment replaces the value and drops the aux data
        d["key"] = ("filecap2", "metadata2")
        self.failUnlessEqual(d["key"], ("filecap2", "metadata2"))
        self.failUnlessEqual(d.get_aux("key"), None)

        # deleting a key removes its aux data as well
        d.set_with_aux("key2", "value2", "aux2")
        self.failUnlessEqual(sorted(d.keys()), ["key", "key2"])
        del d["key2"]
        self.failUnlessEqual(d.keys(), ["key"])
        self.failIf("key2" in d)
        self.failUnlessRaises(KeyError, _get_missing, "key2")
        self.failUnlessEqual(d.get("key2"), None)
        self.failUnlessEqual(d.get_aux("key2"), None)
        d["key2"] = "newvalue2"
        self.failUnlessEqual(d.get("key2"), "newvalue2")
        self.failUnlessEqual(d.get_aux("key2"), None)

        # the constructor accepts a dict, a list of pairs, or keyword args
        d = dictutil.AuxValueDict({1:2,3:4})
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict([ (1,2), (3,4) ])
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict(one=1, two=2)
        self.failUnlessEqual(sorted(d.keys()), ["one","two"])
        self.failUnlessEqual(d["one"], 1)
        self.failUnlessEqual(d.get_aux("one"), None)
1399
1400 class Pipeline(unittest.TestCase):
1401     def pause(self, *args, **kwargs):
1402         d = defer.Deferred()
1403         self.calls.append( (d, args, kwargs) )
1404         return d
1405
1406     def failUnlessCallsAre(self, expected):
1407         #print self.calls
1408         #print expected
1409         self.failUnlessEqual(len(self.calls), len(expected), self.calls)
1410         for i,c in enumerate(self.calls):
1411             self.failUnlessEqual(c[1:], expected[i], str(i))
1412
1413     def test_basic(self):
1414         self.calls = []
1415         finished = []
1416         p = pipeline.Pipeline(100)
1417
1418         d = p.flush() # fires immediately
1419         d.addCallbacks(finished.append, log.err)
1420         self.failUnlessEqual(len(finished), 1)
1421         finished = []
1422
1423         d = p.add(10, self.pause, "one")
1424         # the call should start right away, and our return Deferred should
1425         # fire right away
1426         d.addCallbacks(finished.append, log.err)
1427         self.failUnlessEqual(len(finished), 1)
1428         self.failUnlessEqual(finished[0], None)
1429         self.failUnlessCallsAre([ ( ("one",) , {} ) ])
1430         self.failUnlessEqual(p.gauge, 10)
1431
1432         # pipeline: [one]
1433
1434         finished = []
1435         d = p.add(20, self.pause, "two", kw=2)
1436         # pipeline: [one, two]
1437
1438         # the call and the Deferred should fire right away
1439         d.addCallbacks(finished.append, log.err)
1440         self.failUnlessEqual(len(finished), 1)
1441         self.failUnlessEqual(finished[0], None)
1442         self.failUnlessCallsAre([ ( ("one",) , {} ),
1443                                   ( ("two",) , {"kw": 2} ),
1444                                   ])
1445         self.failUnlessEqual(p.gauge, 30)
1446
1447         self.calls[0][0].callback("one-result")
1448         # pipeline: [two]
1449         self.failUnlessEqual(p.gauge, 20)
1450
1451         finished = []
1452         d = p.add(90, self.pause, "three", "posarg1")
1453         # pipeline: [two, three]
1454         flushed = []
1455         fd = p.flush()
1456         fd.addCallbacks(flushed.append, log.err)
1457         self.failUnlessEqual(flushed, [])
1458
1459         # the call will be made right away, but the return Deferred will not,
1460         # because the pipeline is now full.
1461         d.addCallbacks(finished.append, log.err)
1462         self.failUnlessEqual(len(finished), 0)
1463         self.failUnlessCallsAre([ ( ("one",) , {} ),
1464                                   ( ("two",) , {"kw": 2} ),
1465                                   ( ("three", "posarg1"), {} ),
1466                                   ])
1467         self.failUnlessEqual(p.gauge, 110)
1468
1469         self.failUnlessRaises(pipeline.SingleFileError, p.add, 10, self.pause)
1470
1471         # retiring either call will unblock the pipeline, causing the #3
1472         # Deferred to fire
1473         self.calls[2][0].callback("three-result")
1474         # pipeline: [two]
1475
1476         self.failUnlessEqual(len(finished), 1)
1477         self.failUnlessEqual(finished[0], None)
1478         self.failUnlessEqual(flushed, [])
1479
1480         # retiring call#2 will finally allow the flush() Deferred to fire
1481         self.calls[1][0].callback("two-result")
1482         self.failUnlessEqual(len(flushed), 1)
1483
1484     def test_errors(self):
1485         self.calls = []
1486         p = pipeline.Pipeline(100)
1487
1488         d1 = p.add(200, self.pause, "one")
1489         d2 = p.flush()
1490
1491         finished = []
1492         d1.addBoth(finished.append)
1493         self.failUnlessEqual(finished, [])
1494
1495         flushed = []
1496         d2.addBoth(flushed.append)
1497         self.failUnlessEqual(flushed, [])
1498
1499         self.calls[0][0].errback(ValueError("oops"))
1500
1501         self.failUnlessEqual(len(finished), 1)
1502         f = finished[0]
1503         self.failUnless(isinstance(f, Failure))
1504         self.failUnless(f.check(pipeline.PipelineError))
1505         self.failUnlessIn("PipelineError", str(f.value))
1506         self.failUnlessIn("ValueError", str(f.value))
1507         r = repr(f.value)
1508         self.failUnless("ValueError" in r, r)
1509         f2 = f.value.error
1510         self.failUnless(f2.check(ValueError))
1511
1512         self.failUnlessEqual(len(flushed), 1)
1513         f = flushed[0]
1514         self.failUnless(isinstance(f, Failure))
1515         self.failUnless(f.check(pipeline.PipelineError))
1516         f2 = f.value.error
1517         self.failUnless(f2.check(ValueError))
1518
1519         # now that the pipeline is in the failed state, any new calls will
1520         # fail immediately
1521
1522         d3 = p.add(20, self.pause, "two")
1523
1524         finished = []
1525         d3.addBoth(finished.append)
1526         self.failUnlessEqual(len(finished), 1)
1527         f = finished[0]
1528         self.failUnless(isinstance(f, Failure))
1529         self.failUnless(f.check(pipeline.PipelineError))
1530         r = repr(f.value)
1531         self.failUnless("ValueError" in r, r)
1532         f2 = f.value.error
1533         self.failUnless(f2.check(ValueError))
1534
1535         d4 = p.flush()
1536         flushed = []
1537         d4.addBoth(flushed.append)
1538         self.failUnlessEqual(len(flushed), 1)
1539         f = flushed[0]
1540         self.failUnless(isinstance(f, Failure))
1541         self.failUnless(f.check(pipeline.PipelineError))
1542         f2 = f.value.error
1543         self.failUnless(f2.check(ValueError))
1544
1545     def test_errors2(self):
1546         self.calls = []
1547         p = pipeline.Pipeline(100)
1548
1549         d1 = p.add(10, self.pause, "one")
1550         d2 = p.add(20, self.pause, "two")
1551         d3 = p.add(30, self.pause, "three")
1552         d4 = p.flush()
1553
1554         # one call fails, then the second one succeeds: make sure
1555         # ExpandableDeferredList tolerates the second one
1556
1557         flushed = []
1558         d4.addBoth(flushed.append)
1559         self.failUnlessEqual(flushed, [])
1560
1561         self.calls[0][0].errback(ValueError("oops"))
1562         self.failUnlessEqual(len(flushed), 1)
1563         f = flushed[0]
1564         self.failUnless(isinstance(f, Failure))
1565         self.failUnless(f.check(pipeline.PipelineError))
1566         f2 = f.value.error
1567         self.failUnless(f2.check(ValueError))
1568
1569         self.calls[1][0].callback("two-result")
1570         self.calls[2][0].errback(ValueError("three-error"))
1571
1572         del d1,d2,d3,d4
1573
class SampleError(Exception):
    """Exception type raised on purpose by tests in this file."""
1576
class Log(unittest.TestCase):
    """Tests for allmydata.util.log (imported in this file as tahoe_log)."""

    def test_err(self):
        # Pass a Failure to tahoe_log.err() and then flush the SampleError
        # from trial's logged errors so the test is not reported as [ERROR].
        if not hasattr(self, "flushLoggedErrors"):
            # without flushLoggedErrors, we can't get rid of the
            # twisted.log.err that tahoe_log records, so we can't keep this
            # test from [ERROR]ing
            raise unittest.SkipTest("needs flushLoggedErrors from Twisted-2.5.0")
        try:
            raise SampleError("simple sample")
        except SampleError:
            # Catch only the exception raised above (the original used a bare
            # 'except:'). A no-argument Failure() picks up the exception that
            # is currently being handled.
            f = Failure()
        tahoe_log.err(format="intentional sample error",
                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
        self.flushLoggedErrors(SampleError)
1591
1592
class SimpleSpans:
    """Simple and inefficient form of util.spans.Spans.

    The tests compare the behavior of this reference model against the real
    (efficient) implementation. Every member offset is stored individually
    in the set self._have, so very large spans must not be used with it.

    Constructor forms:
      SimpleSpans()               -> empty
      SimpleSpans(start, length)  -> the offsets start..start+length-1
      SimpleSpans(other)          -> copy of any iterable of (start,length)
                                     pairs (skipped when 'other' is falsy)
    """

    def __init__(self, _span_or_start=None, length=None):
        self._have = set()
        if length is not None:
            self.add(_span_or_start, length)
        elif _span_or_start:
            for (start, length) in _span_or_start:
                self.add(start, length)

    def add(self, start, length):
        """Insert start..start+length-1. Returns self so calls can chain."""
        for i in range(start, start+length):
            self._have.add(i)
        return self

    def remove(self, start, length):
        """Delete start..start+length-1 (absent offsets are ignored).

        Returns self so calls can chain.
        """
        for i in range(start, start+length):
            self._have.discard(i)
        return self

    def each(self):
        """Return a sorted list of every member offset."""
        return sorted(self._have)

    def __iter__(self):
        """Yield one (start, length) tuple per maximal run of offsets."""
        run_start = None
        run_end = None
        for i in sorted(self._have):
            if run_start is None:
                run_start = run_end = i
            elif i == run_end+1:
                # still inside the current run
                run_end = i
            else:
                # gap: emit the finished run and begin a new one
                yield (run_start, run_end-run_start+1)
                run_start = run_end = i
        if run_start is not None:
            yield (run_start, run_end-run_start+1)

    def __nonzero__(self): # this gets us bool() on Python 2
        return bool(self._have)
    # Python 3 spells the truth-value hook __bool__
    __bool__ = __nonzero__

    def len(self):
        """Return the number of member offsets."""
        return len(self._have)

    def __add__(self, other):
        """Return a new instance holding the union of self and other."""
        s = self.__class__(self)
        for (start, length) in other:
            s.add(start, length)
        return s

    def __sub__(self, other):
        """Return a new instance holding self minus the spans of other."""
        s = self.__class__(self)
        for (start, length) in other:
            s.remove(start, length)
        return s

    def __iadd__(self, other):
        for (start, length) in other:
            self.add(start, length)
        return self

    def __isub__(self, other):
        for (start, length) in other:
            self.remove(start, length)
        return self

    def __and__(self, other):
        """Return a new instance holding the offsets present in both."""
        s = self.__class__()
        for i in other.each():
            if i in self._have:
                s.add(i, 1)
        return s

    def __contains__(self, span):
        """True if every offset of the (start, length) pair is a member."""
        # Unpack explicitly: tuple parameters in the signature are a
        # Python-2-only feature.
        (start, length) = span
        for i in range(start, start+length):
            if i not in self._have:
                return False
        return True
1675
class ByteSpans(unittest.TestCase):
    """Tests for util.spans.Spans and for the overlap() helper.

    Spans is exercised directly and, in test_random, against the
    SimpleSpans reference model.
    """

    def test_basic(self):
        s = Spans()
        self.failUnlessEqual(list(s), [])
        self.failIf(s)
        self.failIf((0,1) in s)
        self.failUnlessEqual(s.len(), 0)

        s1 = Spans(3, 4) # 3,4,5,6
        self._check1(s1)

        # the same span, built from Python 2 long integers
        s1 = Spans(long(3), long(4)) # 3,4,5,6
        self._check1(s1)

        s2 = Spans(s1)
        self._check1(s2)

        # modifying the copy must leave the original alone
        s2.add(10,2) # 10,11
        self._check1(s1)
        self.failUnless((10,1) in s2)
        self.failIf((10,1) in s1)
        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
        self.failUnlessEqual(s2.len(), 6)

        # add() and remove() return the instance, so they can be chained
        s2.add(15,2).add(20,2)
        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
        self.failUnlessEqual(s2.len(), 10)

        s2.remove(4,3).remove(15,1)
        self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
        self.failUnlessEqual(s2.len(), 6)

        # sanity-check the reference model's intersection too
        s1 = SimpleSpans(3, 4) # 3 4 5 6
        s2 = SimpleSpans(5, 4) # 5 6 7 8
        i = s1 & s2
        self.failUnlessEqual(list(i.each()), [5, 6])

    def _check1(self, s):
        # assert that 's' holds exactly the offsets 3,4,5,6
        self.failUnlessEqual(list(s), [(3,4)])
        self.failUnless(s)
        self.failUnlessEqual(s.len(), 4)
        self.failIf((0,1) in s)
        self.failUnless((3,4) in s)
        self.failUnless((3,1) in s)
        self.failUnless((5,2) in s)
        self.failUnless((6,1) in s)
        self.failIf((6,2) in s)
        self.failIf((7,1) in s)
        self.failUnlessEqual(list(s.each()), [3,4,5,6])

    def test_large(self):
        s = Spans(4, 2**65) # don't do this with a SimpleSpans
        self.failUnlessEqual(list(s), [(4, 2**65)])
        self.failUnless(s)
        self.failUnlessEqual(s.len(), 2**65)
        self.failIf((0,1) in s)
        self.failUnless((4,2) in s)
        self.failUnless((2**65,2) in s)

    def test_math(self):
        s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
        s2 = Spans(5, 3) # 5,6,7
        s3 = Spans(8, 4) # 8,9,10,11

        # difference
        s = s1 - s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
        s = s1 - s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
        s = s2 - s3
        self.failUnlessEqual(list(s.each()), [5,6,7])

        # intersection, in both operand orders
        s = s1 & s2
        self.failUnlessEqual(list(s.each()), [5,6,7])
        s = s2 & s1
        self.failUnlessEqual(list(s.each()), [5,6,7])
        s = s1 & s3
        self.failUnlessEqual(list(s.each()), [8,9])
        s = s3 & s1
        self.failUnlessEqual(list(s.each()), [8,9])
        s = s2 & s3
        self.failUnlessEqual(list(s.each()), [])
        s = s3 & s2
        self.failUnlessEqual(list(s.each()), [])
        s = Spans() & s3
        self.failUnlessEqual(list(s.each()), [])
        s = s3 & Spans()
        self.failUnlessEqual(list(s.each()), [])

        # union
        s = s1 + s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
        s = s1 + s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
        s = s2 + s3
        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])

        # in-place difference, applied to a copy each time
        s = Spans(s1)
        s -= s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
        s = Spans(s1)
        s -= s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
        s = Spans(s2)
        s -= s3
        self.failUnlessEqual(list(s.each()), [5,6,7])

        # in-place union, applied to a copy each time
        s = Spans(s1)
        s += s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
        s = Spans(s1)
        s += s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
        s = Spans(s2)
        s += s3
        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])

    def test_random(self):
        # attempt to increase coverage of corner cases by comparing behavior
        # of a simple-but-slow model implementation against the
        # complex-but-fast actual implementation, in a large number of random
        # operations
        #
        # Each operation and its arguments are taken from the hex digits of
        # a hash of the iteration counter, so the sequence is the same on
        # every run.
        S1 = SimpleSpans
        S2 = Spans
        s1 = S1(); s2 = S2()
        seed = ""
        def _create(subseed):
            # build a matching (model, real) pair holding ten derived spans
            ns1 = S1(); ns2 = S2()
            for i in range(10):
                what = _hash(subseed+str(i)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1,int(what[5:6], 16))
                ns1.add(start, length); ns2.add(start, length)
            return ns1, ns2

        #print
        for i in range(1000):
            what = _hash(seed+str(i)).hexdigest()
            op = what[0]
            subop = what[1]
            start = int(what[2:4], 16)
            length = max(1,int(what[5:6], 16))
            #print what
            if op in "0":
                # replace both objects, using one of the constructor forms
                if subop in "01234":
                    s1 = S1(); s2 = S2()
                elif subop in "5678":
                    s1 = S1(start, length); s2 = S2(start, length)
                else:
                    s1 = S1(s1); s2 = S2(s2)
                #print "s2 = %s" % s2.dump()
            elif op in "123":
                #print "s2.add(%d,%d)" % (start, length)
                s1.add(start, length); s2.add(start, length)
            elif op in "456":
                #print "s2.remove(%d,%d)" % (start, length)
                s1.remove(start, length); s2.remove(start, length)
            elif op in "78":
                ns1, ns2 = _create(what[7:11])
                #print "s2 + %s" % ns2.dump()
                s1 = s1 + ns1; s2 = s2 + ns2
            elif op in "9a":
                ns1, ns2 = _create(what[7:11])
                #print "%s - %s" % (s2.dump(), ns2.dump())
                s1 = s1 - ns1; s2 = s2 - ns2
            elif op in "bc":
                ns1, ns2 = _create(what[7:11])
                #print "s2 += %s" % ns2.dump()
                s1 += ns1; s2 += ns2
            elif op in "de":
                ns1, ns2 = _create(what[7:11])
                #print "%s -= %s" % (s2.dump(), ns2.dump())
                s1 -= ns1; s2 -= ns2
            else:
                ns1, ns2 = _create(what[7:11])
                #print "%s &= %s" % (s2.dump(), ns2.dump())
                s1 = s1 & ns1; s2 = s2 & ns2
            #print "s2 now %s" % s2.dump()
            # after every operation the model and the real object must agree
            self.failUnlessEqual(list(s1.each()), list(s2.each()))
            self.failUnlessEqual(s1.len(), s2.len())
            self.failUnlessEqual(bool(s1), bool(s2))
            self.failUnlessEqual(list(s1), list(s2))
            # ... including on ten derived membership queries
            for j in range(10):
                what = _hash(what[12:14]+str(j)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1, int(what[5:6], 16))
                span = (start, length)
                self.failUnlessEqual(bool(span in s1), bool(span in s2))


    # s()
    # s(start,length)
    # s(s0)
    # s.add(start,length) : returns s
    # s.remove(start,length)
    # s.each() -> list of byte offsets, mostly for testing
    # list(s) -> list of (start,length) tuples, one per span
    # (start,length) in s -> True if (start..start+length-1) are all members
    #  NOT equivalent to x in list(s)
    # s.len() -> number of bytes, for testing, bool(), and accounting/limiting
    # bool(s)  (__nonzero__)
    # s = s1+s2, s1-s2, +=s1, -=s1

    def test_overlap(self):
        # exhaustively check overlap() for small starts and lengths
        for a in range(20):
            for b in range(10):
                for c in range(20):
                    for d in range(10):
                        self._test_overlap(a,b,c,d)

    def _test_overlap(self, a, b, c, d):
        # compare overlap(a,b,c,d) with the intersection of the explicit
        # offset sets for the spans (start=a,length=b) and (start=c,length=d)
        s1 = set(range(a,a+b))
        s2 = set(range(c,c+d))
        #print "---"
        #self._show_overlap(s1, "1")
        #self._show_overlap(s2, "2")
        o = overlap(a,b,c,d)
        expected = s1.intersection(s2)
        if not expected:
            self.failUnlessEqual(o, None)
        else:
            start,length = o
            so = set(range(start,start+length))
            #self._show_overlap(so, "o")
            self.failUnlessEqual(so, expected)

    def _show_overlap(self, s, c):
        # Debugging aid (only used by the commented-out calls above): write
        # one line to stdout with the character 'c' in every column whose
        # index is a member of the set 's', and a space in the others.
        out = sys.stdout
        if s:
            # +1 so that the largest member gets drawn as well
            for i in range(max(s)+1):
                if i in s:
                    out.write(c)
                else:
                    out.write(" ")
        out.write("\n")
1909
def extend(s, start, length, fill):
    """Return s padded on the right so that it reaches offset start+length.

    The padding repeats the single character 'fill'. When s is already
    long enough it is returned unchanged and 'fill' is not checked.
    """
    shortfall = (start + length) - len(s)
    if shortfall <= 0:
        return s
    assert len(fill) == 1
    return s + shortfall * fill
1915
def replace(s, start, data):
    """Return a copy of s in which 'data' overwrites s[start:start+len(data)].

    s must already be long enough to hold the whole of 'data'.
    """
    end = start + len(data)
    assert end <= len(s)
    head = s[:start]
    tail = s[end:]
    return head + data + tail
1919
class SimpleDataSpans:
    """Simple-but-slow stand-in for DataSpans, used as a reference model.

    The state is two parallel strings indexed by offset: self.missing holds
    "1" where a byte is missing and "0" where it is present, and self.data
    holds the bytes themselves (with " " filling offsets never written).
    """

    def __init__(self, other=None):
        self.missing = "" # "1" where missing, "0" where found
        self.data = ""
        if other:
            # copy the contents of another object that offers get_chunks()
            for (offset, chunk) in other.get_chunks():
                self.add(offset, chunk)

    def __nonzero__(self): # this gets us bool()
        return self.len()

    def len(self):
        # number of offsets that currently hold data
        return self.missing.count("0")

    def _dump(self):
        # list of every offset that currently holds data, in ascending order
        present = []
        for (offset, flag) in enumerate(self.missing):
            if flag == "0":
                present.append(offset)
        return present

    def _have(self, start, length):
        # True only when all of start..start+length-1 is present
        window = self.missing[start:start+length]
        if len(window) < length or not window:
            return False
        return "1" not in window

    def get_chunks(self):
        # yields one (offset, single character) pair per present offset
        for offset in self._dump():
            yield (offset, self.data[offset])

    def get_spans(self):
        spans = []
        for (offset, chunk) in self.get_chunks():
            spans.append((offset, len(chunk)))
        return SimpleSpans(spans)

    def get(self, start, length):
        # the requested data, or None unless every byte of it is present
        if not self._have(start, length):
            return None
        return self.data[start:start+length]

    def pop(self, start, length):
        # like get(), but also removes the data when it was found
        found = self.get(start, length)
        if found:
            self.remove(start, length)
        return found

    def remove(self, start, length):
        # mark the range as missing (self.data is left untouched)
        padded = extend(self.missing, start, length, "1")
        self.missing = replace(padded, start, "1"*length)

    def add(self, start, data):
        size = len(data)
        padded_flags = extend(self.missing, start, size, "1")
        self.missing = replace(padded_flags, start, "0"*size)
        padded_data = extend(self.data, start, size, " ")
        self.data = replace(padded_data, start, data)
1962
1963
1964 class StringSpans(unittest.TestCase):
1965     def do_basic(self, klass):
1966         ds = klass()
1967         self.failUnlessEqual(ds.len(), 0)
1968         self.failUnlessEqual(list(ds._dump()), [])
1969         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
1970         s = ds.get_spans()
1971         self.failUnlessEqual(ds.get(0, 4), None)
1972         self.failUnlessEqual(ds.pop(0, 4), None)
1973         ds.remove(0, 4)
1974
1975         ds.add(2, "four")
1976         self.failUnlessEqual(ds.len(), 4)
1977         self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
1978         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1979         s = ds.get_spans()
1980         self.failUnless((2,2) in s)
1981         self.failUnlessEqual(ds.get(0, 4), None)
1982         self.failUnlessEqual(ds.pop(0, 4), None)
1983         self.failUnlessEqual(ds.get(4, 4), None)
1984
1985         ds2 = klass(ds)
1986         self.failUnlessEqual(ds2.len(), 4)
1987         self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
1988         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
1989         self.failUnlessEqual(ds2.get(0, 4), None)
1990         self.failUnlessEqual(ds2.pop(0, 4), None)
1991         self.failUnlessEqual(ds2.pop(2, 3), "fou")
1992         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
1993         self.failUnlessEqual(ds2.get(2, 3), None)
1994         self.failUnlessEqual(ds2.get(5, 1), "r")
1995         self.failUnlessEqual(ds.get(2, 3), "fou")
1996         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1997
1998         ds.add(0, "23")
1999         self.failUnlessEqual(ds.len(), 6)
2000         self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
2001         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
2002         self.failUnlessEqual(ds.get(0, 4), "23fo")
2003         self.failUnlessEqual(ds.pop(0, 4), "23fo")
2004         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
2005         self.failUnlessEqual(ds.get(0, 4), None)
2006         self.failUnlessEqual(ds.pop(0, 4), None)
2007
2008         ds = klass()
2009         ds.add(2, "four")
2010         ds.add(3, "ea")
2011         self.failUnlessEqual(ds.get(2, 4), "fear")
2012
2013         ds = klass()
2014         ds.add(2L, "four")
2015         ds.add(3L, "ea")
2016         self.failUnlessEqual(ds.get(2L, 4L), "fear")
2017
2018
2019     def do_scan(self, klass):
2020         # do a test with gaps and spans of size 1 and 2
2021         #  left=(1,11) * right=(1,11) * gapsize=(1,2)
2022         # 111, 112, 121, 122, 211, 212, 221, 222
2023         #    211
2024         #      121
2025         #         112
2026         #            212
2027         #               222
2028         #                   221
2029         #                      111
2030         #                        122
2031         #  11 1  1 11 11  11  1 1  111
2032         # 0123456789012345678901234567
2033         # abcdefghijklmnopqrstuvwxyz-=
2034         pieces = [(1, "bc"),
2035                   (4, "e"),
2036                   (7, "h"),
2037                   (9, "jk"),
2038                   (12, "mn"),
2039                   (16, "qr"),
2040                   (20, "u"),
2041                   (22, "w"),
2042                   (25, "z-="),
2043                   ]
2044         p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
2045         S = "abcdefghijklmnopqrstuvwxyz-="
2046         # TODO: when adding data, add capital letters, to make sure we aren't
2047         # just leaving the old data in place
2048         l = len(S)
2049         def base():
2050             ds = klass()
2051             for start, data in pieces:
2052                 ds.add(start, data)
2053             return ds
2054         def dump(s):
2055             p = set(s._dump())
2056             d = "".join([((i not in p) and " " or S[i]) for i in range(l)])
2057             assert len(d) == l
2058             return d
2059         DEBUG = False
2060         for start in range(0, l):
2061             for end in range(start+1, l):
2062                 # add [start-end) to the baseline
2063                 which = "%d-%d" % (start, end-1)
2064                 p_added = set(range(start, end))
2065                 b = base()
2066                 if DEBUG:
2067                     print
2068                     print dump(b), which
2069                     add = klass(); add.add(start, S[start:end])
2070                     print dump(add)
2071                 b.add(start, S[start:end])
2072                 if DEBUG:
2073                     print dump(b)
2074                 # check that the new span is there
2075                 d = b.get(start, end-start)
2076                 self.failUnlessEqual(d, S[start:end], which)
2077                 # check that all the original pieces are still there
2078                 for t_start, t_data in pieces:
2079                     t_len = len(t_data)
2080                     self.failUnlessEqual(b.get(t_start, t_len),
2081                                          S[t_start:t_start+t_len],
2082                                          "%s %d+%d" % (which, t_start, t_len))
2083                 # check that a lot of subspans are mostly correct
2084                 for t_start in range(l):
2085                     for t_len in range(1,4):
2086                         d = b.get(t_start, t_len)
2087                         if d is not None:
2088                             which2 = "%s+(%d-%d)" % (which, t_start,
2089                                                      t_start+t_len-1)
2090                             self.failUnlessEqual(d, S[t_start:t_start+t_len],
2091                                                  which2)
2092                         # check that removing a subspan gives the right value
2093                         b2 = klass(b)
2094                         b2.remove(t_start, t_len)
2095                         removed = set(range(t_start, t_start+t_len))
2096                         for i in range(l):
2097                             exp = (((i in p_elements) or (i in p_added))
2098                                    and (i not in removed))
2099                             which2 = "%s-(%d-%d)" % (which, t_start,
2100                                                      t_start+t_len-1)
2101                             self.failUnlessEqual(bool(b2.get(i, 1)), exp,
2102                                                  which2+" %d" % i)
2103
2104     def test_test(self):
2105         self.do_basic(SimpleDataSpans)
2106         self.do_scan(SimpleDataSpans)
2107
2108     def test_basic(self):
2109         self.do_basic(DataSpans)
2110         self.do_scan(DataSpans)
2111
2112     def test_random(self):
2113         # attempt to increase coverage of corner cases by comparing behavior
2114         # of a simple-but-slow model implementation against the
2115         # complex-but-fast actual implementation, in a large number of random
2116         # operations
2117         S1 = SimpleDataSpans
2118         S2 = DataSpans
2119         s1 = S1(); s2 = S2()
2120         seed = ""
2121         def _randstr(length, seed):
2122             created = 0
2123             pieces = []
2124             while created < length:
2125                 piece = _hash(seed + str(created)).hexdigest()
2126                 pieces.append(piece)
2127                 created += len(piece)
2128             return "".join(pieces)[:length]
2129         def _create(subseed):
2130             ns1 = S1(); ns2 = S2()
2131             for i in range(10):
2132                 what = _hash(subseed+str(i)).hexdigest()
2133                 start = int(what[2:4], 16)
2134                 length = max(1,int(what[5:6], 16))
2135                 ns1.add(start, _randstr(length, what[7:9]));
2136                 ns2.add(start, _randstr(length, what[7:9]))
2137             return ns1, ns2
2138
2139         #print
2140         for i in range(1000):
2141             what = _hash(seed+str(i)).hexdigest()
2142             op = what[0]
2143             subop = what[1]
2144             start = int(what[2:4], 16)
2145             length = max(1,int(what[5:6], 16))
2146             #print what
2147             if op in "0":
2148                 if subop in "0123456":
2149                     s1 = S1(); s2 = S2()
2150                 else:
2151                     s1, s2 = _create(what[7:11])
2152                 #print "s2 = %s" % list(s2._dump())
2153             elif op in "123456":
2154                 #print "s2.add(%d,%d)" % (start, length)
2155                 s1.add(start, _randstr(length, what[7:9]));
2156                 s2.add(start, _randstr(length, what[7:9]))
2157             elif op in "789abc":
2158                 #print "s2.remove(%d,%d)" % (start, length)
2159                 s1.remove(start, length); s2.remove(start, length)
2160             else:
2161                 #print "s2.pop(%d,%d)" % (start, length)
2162                 d1 = s1.pop(start, length); d2 = s2.pop(start, length)
2163                 self.failUnlessEqual(d1, d2)
2164             #print "s1 now %s" % list(s1._dump())
2165             #print "s2 now %s" % list(s2._dump())
2166             self.failUnlessEqual(s1.len(), s2.len())
2167             self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
2168             for j in range(100):
2169                 what = _hash(what[12:14]+str(j)).hexdigest()
2170                 start = int(what[2:4], 16)
2171                 length = max(1, int(what[5:6], 16))
2172                 d1 = s1.get(start, length); d2 = s2.get(start, length)
2173                 self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))