git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
Add a byte-spans utility class, like perl's Set::IntSpan for .newsrc files.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): """Line-number probe for HumanReadable.test_repr (expects test_util.py:2); keep on one line."""
3
4 import os, time, sys
5 from StringIO import StringIO
6 from twisted.trial import unittest
7 from twisted.internet import defer, reactor
8 from twisted.python.failure import Failure
9 from twisted.python import log
10 from hashlib import md5
11
12 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
13 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
14 from allmydata.util import limiter, time_format, pollmixin, cachedir
15 from allmydata.util import statistics, dictutil, pipeline
16 from allmydata.util import log as tahoe_log
17 from allmydata.util.spans import Spans, overlap, DataSpans
18
class Base32(unittest.TestCase):
    """Tests for allmydata.util.base32 encoding and decoding."""

    def test_b2a_matches_Pythons(self):
        # base32.b2a must agree with the stdlib encoder once the stdlib's
        # trailing '=' padding is removed and the result is lowercased.
        import base64
        raw = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        expected = base64.b32encode(raw).rstrip('=').lower()
        self.failUnlessEqual(base32.b2a(raw), expected)

    def test_b2a(self):
        self.failUnlessEqual(base32.b2a("\x12\x34"), "ci2a")

    def test_b2a_or_none(self):
        # None passes straight through; real data is encoded like b2a().
        self.failUnlessEqual(base32.b2a_or_none(None), None)
        self.failUnlessEqual(base32.b2a_or_none("\x12\x34"), "ci2a")

    def test_a2b(self):
        self.failUnlessEqual(base32.a2b("ci2a"), "\x12\x34")
        # "b0gus" contains '0', which is not a base32 character, so a2b
        # must reject it with an AssertionError.
        self.failUnlessRaises(AssertionError, base32.a2b, "b0gus")
36
class IDLib(unittest.TestCase):
    """Tests for allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        # A 20-byte all-zero node id must encode to 32 'a' characters.
        all_zero_id = "\x00" * 20
        self.failUnlessEqual(idlib.nodeid_b2a(all_zero_id), "a" * 32)
40
class NoArgumentException(Exception):
    """Exception whose constructor accepts no arguments.

    HumanReadable.test_repr raises it to check how hr() renders an
    exception instance that carries no args.
    """
    def __init__(self):
        """Take no arguments and do no initialisation of our own."""
44
45 class HumanReadable(unittest.TestCase):
46     def test_repr(self):
47         hr = humanreadable.hr
48         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
49         self.failUnlessEqual(hr(self.test_repr),
50                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
51         self.failUnlessEqual(hr(1L), "1")
52         self.failUnlessEqual(hr(10**40),
53                              "100000000000000000...000000000000000000")
54         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
55         self.failUnlessEqual(hr([1,2]), "[1, 2]")
56         self.failUnlessEqual(hr({1:2}), "{1:2}")
57         try:
58             raise ValueError
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<ValueError: ()>" # python-2.4
62                 or hr(e) == "ValueError()") # python-2.5
63         try:
64             raise ValueError("oops")
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<ValueError: 'oops'>" # python-2.4
68                 or hr(e) == "ValueError('oops',)") # python-2.5
69         try:
70             raise NoArgumentException
71         except Exception, e:
72             self.failUnless(
73                 hr(e) == "<NoArgumentException>" # python-2.4
74                 or hr(e) == "NoArgumentException()") # python-2.5
75
76
class MyList(list):
    """Plain list subclass that adds no behaviour of its own."""
79
class Math(unittest.TestCase):
    """Tests for the numeric helpers in allmydata.util.mathutil."""

    def test_div_ceil(self):
        f = mathutil.div_ceil
        self.failUnlessEqual(f(0, 1), 0)
        self.failUnlessEqual(f(0, 2), 0)
        # ceil(n / 3) for n = 0..7
        for n, expected in enumerate([0, 1, 1, 1, 2, 2, 2, 3]):
            self.failUnlessEqual(f(n, 3), expected)

    def test_next_multiple(self):
        f = mathutil.next_multiple
        # For each n, row[k-1] is the expected f(n, k): the smallest
        # multiple of k that is >= n.
        table = [(5, [5, 6, 6, 8, 5, 6]),
                 (32, [32, 32, 33, 32, 35, 36, 35, 32, 36,
                       40, 33, 36, 39, 42, 45, 32, 34, 36]),
                 ]
        for n, row in table:
            for idx, expected in enumerate(row):
                self.failUnlessEqual(f(n, idx + 1), expected)
        self.failUnlessEqual(f(32, 589), 589)

    def test_pad_size(self):
        f = mathutil.pad_size
        # padding needed to reach the next multiple of 4, for n = 0..5
        for n, expected in enumerate([0, 3, 2, 1, 0, 3]):
            self.failUnlessEqual(f(n, 4), expected)

    def test_is_power_of_k(self):
        f = mathutil.is_power_of_k
        for k, powers in [(2, (1, 2, 4, 8, 16, 32, 64)),
                          (3, (1, 3, 9, 27, 81))]:
            for i in range(1, 100):
                if i in powers:
                    self.failUnless(f(i, k), "but %d *is* a power of %d" % (i, k))
                else:
                    self.failIf(f(i, k), "but %d is *not* a power of %d" % (i, k))

    def test_next_power_of_k(self):
        f = mathutil.next_power_of_k

        for i, expected in [(0, 1), (1, 1), (2, 2), (3, 4), (4, 4)]:
            self.failUnlessEqual(f(i, 2), expected)
        # (start, stop, expected) over range(start, stop)
        for start, stop, expected in [(5, 8, 8), (9, 16, 16), (17, 32, 32),
                                      (33, 64, 64), (65, 100, 128)]:
            for i in range(start, stop):
                self.failUnlessEqual(f(i, 2), expected, "%d" % i)

        for i, expected in [(0, 1), (1, 1), (2, 3), (3, 3)]:
            self.failUnlessEqual(f(i, 3), expected)
        for start, stop, expected in [(4, 9, 9), (10, 27, 27),
                                      (28, 81, 81), (82, 200, 243)]:
            for i in range(start, stop):
                self.failUnlessEqual(f(i, 3), expected, "%d" % i)

    def test_ave(self):
        f = mathutil.ave
        self.failUnlessEqual(f([1,2,3]), 2)
        self.failUnlessEqual(f([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(f([0.0, 1.0, 1.0]), .666666666666)

    def test_round_sigfigs(self):
        f = mathutil.round_sigfigs
        self.failUnlessEqual(f(22.0/3, 4), 7.3330000000000002)
175
176 class Statistics(unittest.TestCase):
177     def should_assert(self, msg, func, *args, **kwargs):
178         try:
179             func(*args, **kwargs)
180             self.fail(msg)
181         except AssertionError:
182             pass
183
184     def failUnlessListEqual(self, a, b, msg = None):
185         self.failUnlessEqual(len(a), len(b))
186         for i in range(len(a)):
187             self.failUnlessEqual(a[i], b[i], msg)
188
189     def failUnlessListAlmostEqual(self, a, b, places = 7, msg = None):
190         self.failUnlessEqual(len(a), len(b))
191         for i in range(len(a)):
192             self.failUnlessAlmostEqual(a[i], b[i], places, msg)
193
194     def test_binomial_coeff(self):
195         f = statistics.binomial_coeff
196         self.failUnlessEqual(f(20, 0), 1)
197         self.failUnlessEqual(f(20, 1), 20)
198         self.failUnlessEqual(f(20, 2), 190)
199         self.failUnlessEqual(f(20, 8), f(20, 12))
200         self.should_assert("Should assert if n < k", f, 2, 3)
201
202     def test_binomial_distribution_pmf(self):
203         f = statistics.binomial_distribution_pmf
204
205         pmf_comp = f(2, .1)
206         pmf_stat = [0.81, 0.18, 0.01]
207         self.failUnlessListAlmostEqual(pmf_comp, pmf_stat)
208
209         # Summing across a PMF should give the total probability 1
210         self.failUnlessAlmostEqual(sum(pmf_comp), 1)
211         self.should_assert("Should assert if not 0<=p<=1", f, 1, -1)
212         self.should_assert("Should assert if n < 1", f, 0, .1)
213
214         out = StringIO()
215         statistics.print_pmf(pmf_comp, out=out)
216         lines = out.getvalue().splitlines()
217         self.failUnlessEqual(lines[0], "i=0: 0.81")
218         self.failUnlessEqual(lines[1], "i=1: 0.18")
219         self.failUnlessEqual(lines[2], "i=2: 0.01")
220
221     def test_survival_pmf(self):
222         f = statistics.survival_pmf
223         # Cross-check binomial-distribution method against convolution
224         # method.
225         p_list = [.9999] * 100 + [.99] * 50 + [.8] * 20
226         pmf1 = statistics.survival_pmf_via_conv(p_list)
227         pmf2 = statistics.survival_pmf_via_bd(p_list)
228         self.failUnlessListAlmostEqual(pmf1, pmf2)
229         self.failUnlessTrue(statistics.valid_pmf(pmf1))
230         self.should_assert("Should assert if p_i > 1", f, [1.1]);
231         self.should_assert("Should assert if p_i < 0", f, [-.1]);
232
233     def test_repair_count_pmf(self):
234         survival_pmf = statistics.binomial_distribution_pmf(5, .9)
235         repair_pmf = statistics.repair_count_pmf(survival_pmf, 3)
236         # repair_pmf[0] == sum(survival_pmf[0,1,2,5])
237         # repair_pmf[1] == survival_pmf[4]
238         # repair_pmf[2] = survival_pmf[3]
239         self.failUnlessListAlmostEqual(repair_pmf,
240                                        [0.00001 + 0.00045 + 0.0081 + 0.59049,
241                                         .32805,
242                                         .0729,
243                                         0, 0, 0])
244
245     def test_repair_cost(self):
246         survival_pmf = statistics.binomial_distribution_pmf(5, .9)
247         bwcost = statistics.bandwidth_cost_function
248         cost = statistics.mean_repair_cost(bwcost, 1000,
249                                            survival_pmf, 3, ul_dl_ratio=1.0)
250         self.failUnlessAlmostEqual(cost, 558.90)
251         cost = statistics.mean_repair_cost(bwcost, 1000,
252                                            survival_pmf, 3, ul_dl_ratio=8.0)
253         self.failUnlessAlmostEqual(cost, 1664.55)
254
255         # I haven't manually checked the math beyond here -warner
256         cost = statistics.eternal_repair_cost(bwcost, 1000,
257                                               survival_pmf, 3,
258                                               discount_rate=0, ul_dl_ratio=1.0)
259         self.failUnlessAlmostEqual(cost, 65292.056074766246)
260         cost = statistics.eternal_repair_cost(bwcost, 1000,
261                                               survival_pmf, 3,
262                                               discount_rate=0.05,
263                                               ul_dl_ratio=1.0)
264         self.failUnlessAlmostEqual(cost, 9133.6097158191551)
265
266     def test_convolve(self):
267         f = statistics.convolve
268         v1 = [ 1, 2, 3 ]
269         v2 = [ 4, 5, 6 ]
270         v3 = [ 7, 8 ]
271         v1v2result = [ 4, 13, 28, 27, 18 ]
272         # Convolution is commutative
273         r1 = f(v1, v2)
274         r2 = f(v2, v1)
275         self.failUnlessListEqual(r1, r2, "Convolution should be commutative")
276         self.failUnlessListEqual(r1, v1v2result, "Didn't match known result")
277         # Convolution is associative
278         r1 = f(f(v1, v2), v3)
279         r2 = f(v1, f(v2, v3))
280         self.failUnlessListEqual(r1, r2, "Convolution should be associative")
281         # Convolution is distributive
282         r1 = f(v3, [ a + b for a, b in zip(v1, v2) ])
283         tmp1 = f(v3, v1)
284         tmp2 = f(v3, v2)
285         r2 = [ a + b for a, b in zip(tmp1, tmp2) ]
286         self.failUnlessListEqual(r1, r2, "Convolution should be distributive")
287         # Convolution is scalar multiplication associative
288         tmp1 = f(v1, v2)
289         r1 = [ a * 4 for a in tmp1 ]
290         tmp2 = [ a * 4 for a in v1 ]
291         r2 = f(tmp2, v2)
292         self.failUnlessListEqual(r1, r2, "Convolution should be scalar multiplication associative")
293
294     def test_find_k(self):
295         f = statistics.find_k
296         g = statistics.pr_file_loss
297         plist = [.9] * 10 + [.8] * 10 # N=20
298         t = .0001
299         k = f(plist, t)
300         self.failUnlessEqual(k, 10)
301         self.failUnless(g(plist, k) < t)
302
303     def test_pr_file_loss(self):
304         f = statistics.pr_file_loss
305         plist = [.5] * 10
306         self.failUnlessEqual(f(plist, 3), .0546875)
307
308     def test_pr_backup_file_loss(self):
309         f = statistics.pr_backup_file_loss
310         plist = [.5] * 10
311         self.failUnlessEqual(f(plist, .5, 3), .02734375)
312
313
314 class Asserts(unittest.TestCase):
315     def should_assert(self, func, *args, **kwargs):
316         try:
317             func(*args, **kwargs)
318         except AssertionError, e:
319             return str(e)
320         except Exception, e:
321             self.fail("assert failed with non-AssertionError: %s" % e)
322         self.fail("assert was not caught")
323
324     def should_not_assert(self, func, *args, **kwargs):
325         try:
326             func(*args, **kwargs)
327         except AssertionError, e:
328             self.fail("assertion fired when it should not have: %s" % e)
329         except Exception, e:
330             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
331         return # we're happy
332
333
334     def test_assert(self):
335         f = assertutil._assert
336         self.should_assert(f)
337         self.should_assert(f, False)
338         self.should_not_assert(f, True)
339
340         m = self.should_assert(f, False, "message")
341         self.failUnlessEqual(m, "'message' <type 'str'>", m)
342         m = self.should_assert(f, False, "message1", othermsg=12)
343         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
344         m = self.should_assert(f, False, othermsg="message2")
345         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
346
347     def test_precondition(self):
348         f = assertutil.precondition
349         self.should_assert(f)
350         self.should_assert(f, False)
351         self.should_not_assert(f, True)
352
353         m = self.should_assert(f, False, "message")
354         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
355         m = self.should_assert(f, False, "message1", othermsg=12)
356         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
357         m = self.should_assert(f, False, othermsg="message2")
358         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
359
360     def test_postcondition(self):
361         f = assertutil.postcondition
362         self.should_assert(f)
363         self.should_assert(f, False)
364         self.should_not_assert(f, True)
365
366         m = self.should_assert(f, False, "message")
367         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
368         m = self.should_assert(f, False, "message1", othermsg=12)
369         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
370         m = self.should_assert(f, False, othermsg="message2")
371         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
372
373 class FileUtil(unittest.TestCase):
374     def mkdir(self, basedir, path, mode=0777):
375         fn = os.path.join(basedir, path)
376         fileutil.make_dirs(fn, mode)
377
378     def touch(self, basedir, path, mode=None, data="touch\n"):
379         fn = os.path.join(basedir, path)
380         f = open(fn, "w")
381         f.write(data)
382         f.close()
383         if mode is not None:
384             os.chmod(fn, mode)
385
386     def test_rm_dir(self):
387         basedir = "util/FileUtil/test_rm_dir"
388         fileutil.make_dirs(basedir)
389         # create it again to test idempotency
390         fileutil.make_dirs(basedir)
391         d = os.path.join(basedir, "doomed")
392         self.mkdir(d, "a/b")
393         self.touch(d, "a/b/1.txt")
394         self.touch(d, "a/b/2.txt", 0444)
395         self.touch(d, "a/b/3.txt", 0)
396         self.mkdir(d, "a/c")
397         self.touch(d, "a/c/1.txt")
398         self.touch(d, "a/c/2.txt", 0444)
399         self.touch(d, "a/c/3.txt", 0)
400         os.chmod(os.path.join(d, "a/c"), 0444)
401         self.mkdir(d, "a/d")
402         self.touch(d, "a/d/1.txt")
403         self.touch(d, "a/d/2.txt", 0444)
404         self.touch(d, "a/d/3.txt", 0)
405         os.chmod(os.path.join(d, "a/d"), 0)
406
407         fileutil.rm_dir(d)
408         self.failIf(os.path.exists(d))
409         # remove it again to test idempotency
410         fileutil.rm_dir(d)
411
412     def test_remove_if_possible(self):
413         basedir = "util/FileUtil/test_remove_if_possible"
414         fileutil.make_dirs(basedir)
415         self.touch(basedir, "here")
416         fn = os.path.join(basedir, "here")
417         fileutil.remove_if_possible(fn)
418         self.failIf(os.path.exists(fn))
419         fileutil.remove_if_possible(fn) # should be idempotent
420         fileutil.rm_dir(basedir)
421         fileutil.remove_if_possible(fn) # should survive errors
422
423     def test_open_or_create(self):
424         basedir = "util/FileUtil/test_open_or_create"
425         fileutil.make_dirs(basedir)
426         fn = os.path.join(basedir, "here")
427         f = fileutil.open_or_create(fn)
428         f.write("stuff.")
429         f.close()
430         f = fileutil.open_or_create(fn)
431         f.seek(0, 2)
432         f.write("more.")
433         f.close()
434         f = open(fn, "r")
435         data = f.read()
436         f.close()
437         self.failUnlessEqual(data, "stuff.more.")
438
439     def test_NamedTemporaryDirectory(self):
440         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
441         fileutil.make_dirs(basedir)
442         td = fileutil.NamedTemporaryDirectory(dir=basedir)
443         name = td.name
444         self.failUnless(basedir in name)
445         self.failUnless(basedir in repr(td))
446         self.failUnless(os.path.isdir(name))
447         del td
448         # it is conceivable that we need to force gc here, but I'm not sure
449         self.failIf(os.path.isdir(name))
450
451     def test_rename(self):
452         basedir = "util/FileUtil/test_rename"
453         fileutil.make_dirs(basedir)
454         self.touch(basedir, "here")
455         fn = os.path.join(basedir, "here")
456         fn2 = os.path.join(basedir, "there")
457         fileutil.rename(fn, fn2)
458         self.failIf(os.path.exists(fn))
459         self.failUnless(os.path.exists(fn2))
460
461     def test_du(self):
462         basedir = "util/FileUtil/test_du"
463         fileutil.make_dirs(basedir)
464         d = os.path.join(basedir, "space-consuming")
465         self.mkdir(d, "a/b")
466         self.touch(d, "a/b/1.txt", data="a"*10)
467         self.touch(d, "a/b/2.txt", data="b"*11)
468         self.mkdir(d, "a/c")
469         self.touch(d, "a/c/1.txt", data="c"*12)
470         self.touch(d, "a/c/2.txt", data="d"*13)
471
472         used = fileutil.du(basedir)
473         self.failUnlessEqual(10+11+12+13, used)
474
475     def test_abspath_expanduser_unicode(self):
476         self.failUnlessRaises(AssertionError, fileutil.abspath_expanduser_unicode, "bytestring")
477
478         saved_cwd = os.path.normpath(os.getcwdu())
479         abspath_cwd = fileutil.abspath_expanduser_unicode(u".")
480         self.failUnless(isinstance(saved_cwd, unicode), saved_cwd)
481         self.failUnless(isinstance(abspath_cwd, unicode), abspath_cwd)
482         self.failUnlessEqual(abspath_cwd, saved_cwd)
483
484         # adapted from <http://svn.python.org/view/python/branches/release26-maint/Lib/test/test_posixpath.py?view=markup&pathrev=78279#test_abspath>
485
486         self.failUnlessIn(u"foo", fileutil.abspath_expanduser_unicode(u"foo"))
487         self.failIfIn(u"~", fileutil.abspath_expanduser_unicode(u"~"))
488
489         cwds = ['cwd']
490         try:
491             cwds.append(u'\xe7w\xf0'.encode(sys.getfilesystemencoding()
492                                             or 'ascii'))
493         except UnicodeEncodeError:
494             pass # the cwd can't be encoded -- test with ascii cwd only
495
496         for cwd in cwds:
497             try:
498                 os.mkdir(cwd)
499                 os.chdir(cwd)
500                 for upath in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\', u'~'):
501                     uabspath = fileutil.abspath_expanduser_unicode(upath)
502                     self.failUnless(isinstance(uabspath, unicode), uabspath)
503             finally:
504                 os.chdir(saved_cwd)
505
class PollMixinTests(unittest.TestCase):
    """Tests for allmydata.util.pollmixin.PollMixin.poll()."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # check_f succeeds on the first call
        return self.pm.poll(check_f=lambda: True, pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # check_f fails once, then succeeds
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next, pollinterval=0.1)

    def test_timeout(self):
        # check_f never succeeds, so the Deferred must errback with
        # pollmixin.TimeoutError rather than fire its callback.
        d = self.pm.poll(check_f=lambda: False,
                         pollinterval=0.01,
                         timeout=1)
        def _unexpected_success(res):
            self.fail("poll should have failed, not returned %s" % (res,))
        def _expected_timeout(failure):
            failure.trap(pollmixin.TimeoutError)
            return None # success
        d.addCallbacks(_unexpected_success, _expected_timeout)
        return d
532
class DeferredUtilTests(unittest.TestCase):
    """Tests for the Deferred helpers in allmydata.util.deferredutil."""

    def test_gather_results(self):
        first, second = defer.Deferred(), defer.Deferred()
        gathered = deferredutil.gatherResults([first, second])
        # `second` never fires, so the combined Deferred must errback as
        # soon as `first` fails.
        first.errback(ValueError("BAD"))
        def _unexpected(res):
            self.fail("Should have errbacked, not resulted in %s" % (res,))
        def _expected(failure):
            failure.trap(ValueError)
        gathered.addCallbacks(_unexpected, _expected)
        return gathered

    def test_success(self):
        d1, d2 = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([d1, d2])
        dlss.addCallbacks(successes.append, failures.append)
        d1.callback(1)
        d2.callback(2)
        # one callback carrying the results in input order
        self.failUnlessEqual(successes, [[1, 2]])
        self.failUnlessEqual(failures, [])

    def test_failure(self):
        d1, d2 = defer.Deferred(), defer.Deferred()
        successes, failures = [], []
        dlss = deferredutil.DeferredListShouldSucceed([d1, d2])
        dlss.addCallbacks(successes.append, failures.append)
        # NOTE(review): presumably these swallow the per-Deferred failure
        # so it is not reported as an unhandled error.
        for d in (d1, d2):
            d.addErrback(lambda _ignore: None)
        d1.callback(1)
        d2.errback(ValueError())
        self.failUnlessEqual(successes, [])
        self.failUnlessEqual(len(failures), 1)
        failure = failures[0]
        self.failUnless(isinstance(failure, Failure))
        self.failUnless(failure.check(ValueError))
572
class HashUtilTests(unittest.TestCase):
    """Tests for the tagged hash functions in allmydata.util.hashutil."""

    def test_random_key(self):
        self.failUnlessEqual(len(hashutil.random_key()), hashutil.KEYLEN)

    def test_sha256d(self):
        oneshot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        first = hasher.digest()
        # digest() must be repeatable on the same hasher
        second = hasher.digest()
        self.failUnlessEqual(oneshot, first)
        self.failUnlessEqual(first, second)

    def test_sha256d_truncated(self):
        oneshot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(oneshot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(oneshot, incremental)

    def test_chk(self):
        oneshot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        self.failUnlessEqual(oneshot, hasher.digest())

    def test_hashers(self):
        # Each one-shot *_hash(data) must equal feeding the same data to the
        # matching incremental *_hasher().
        pairs = [
            (hashutil.block_hash, hashutil.block_hasher),
            (hashutil.uri_extension_hash, hashutil.uri_extension_hasher),
            (hashutil.plaintext_hash, hashutil.plaintext_hasher),
            (hashutil.crypttext_hash, hashutil.crypttext_hasher),
            (hashutil.crypttext_segment_hash, hashutil.crypttext_segment_hasher),
            (hashutil.plaintext_segment_hash, hashutil.plaintext_segment_hasher),
            ]
        for oneshot, make_hasher in pairs:
            hasher = make_hasher()
            hasher.update("foo")
            self.failUnlessEqual(oneshot("foo"), hasher.digest())

    def test_constant_time_compare(self):
        self.failUnless(hashutil.constant_time_compare("a", "a"))
        self.failUnless(hashutil.constant_time_compare("ab", "ab"))
        self.failIf(hashutil.constant_time_compare("a", "b"))
        self.failIf(hashutil.constant_time_compare("a", "aa"))

    def _testknown(self, hashf, expected_a, *args):
        # hashf(*args), base32-encoded, must equal expected_a
        self.failUnlessEqual(base32.b2a(hashf(*args)), expected_a)

    def test_known_answers(self):
        # assert backwards compatibility: these digests must never change
        H = hashutil
        nul20 = "\x00" * 20
        known = [
            (H.storage_index_hash, "qb5igbhcc5esa6lwqorsy7e6am", ("",)),
            (H.block_hash, "msjr5bh4evuh7fa3zw7uovixfbvlnstr5b65mrerwfnvjxig2jvq", ("",)),
            (H.uri_extension_hash, "wthsu45q7zewac2mnivoaa4ulh5xvbzdmsbuyztq2a5fzxdrnkka", ("",)),
            (H.plaintext_hash, "5lz5hwz3qj3af7n6e3arblw7xzutvnd3p3fjsngqjcb7utf3x3da", ("",)),
            (H.crypttext_hash, "itdj6e4njtkoiavlrmxkvpreosscssklunhwtvxn6ggho4rkqwga", ("",)),
            (H.crypttext_segment_hash, "aovy5aa7jej6ym5ikgwyoi4pxawnoj3wtaludjz7e2nb5xijb7aa", ("",)),
            (H.plaintext_segment_hash, "4fdgf6qruaisyukhqcmoth4t3li6bkolbxvjy4awwcpprdtva7za", ("",)),
            (H.convergence_hash, "3mo6ni7xweplycin6nowynw2we", (3, 10, 100, "", "converge")),
            (H.my_renewal_secret_hash, "ujhr5k5f7ypkp67jkpx6jl4p47pyta7hu5m527cpcgvkafsefm6q", ("",)),
            (H.my_cancel_secret_hash, "rjwzmafe2duixvqy6h47f5wfrokdziry6zhx4smew4cj6iocsfaa", ("",)),
            (H.file_renewal_secret_hash, "hzshk2kf33gzbd5n3a6eszkf6q6o6kixmnag25pniusyaulqjnia", ("", "si")),
            (H.file_cancel_secret_hash, "bfciwvr6w7wcavsngxzxsxxaszj72dej54n4tu2idzp6b74g255q", ("", "si")),
            (H.bucket_renewal_secret_hash, "e7imrzgzaoashsncacvy3oysdd2m5yvtooo4gmj4mjlopsazmvuq", ("", nul20)),
            (H.bucket_cancel_secret_hash, "dvdujeyxeirj6uux6g7xcf4lvesk632aulwkzjar7srildvtqwma", ("", nul20)),
            (H.hmac, "c54ypfi6pevb3nvo6ba42jtglpkry2kbdopqsi7dgrm4r7tw5sra", ("tag", "")),
            (H.mutable_rwcap_key_hash, "6rvn2iqrghii5n4jbbwwqqsnqu", ("iv", "wk")),
            (H.ssk_writekey_hash, "ykpgmdbpgbb6yqz5oluw2q26ye", ("",)),
            (H.ssk_write_enabler_master_hash, "izbfbfkoait4dummruol3gy2bnixrrrslgye6ycmkuyujnenzpia", ("",)),
            (H.ssk_write_enabler_hash, "fuu2dvx7g6gqu5x22vfhtyed7p4pd47y5hgxbqzgrlyvxoev62tq", ("wk", nul20)),
            (H.ssk_pubkey_fingerprint_hash, "3opzw4hhm2sgncjx224qmt5ipqgagn7h5zivnfzqycvgqgmgz35q", ("",)),
            (H.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", ("",)),
            (H.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", ("iv", "rk")),
            (H.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", ("",)),
            ]
        for hashf, expected_a, args in known:
            self._testknown(hashf, expected_a, *args)
671
672
class Abbreviate(unittest.TestCase):
    """Tests for the formatting/parsing helpers in allmydata.util.abbreviate."""

    def test_time(self):
        a = abbreviate.abbreviate_time
        MIN = 60
        HOUR = 60*MIN
        DAY = 24*HOUR
        MONTH = 30*DAY
        YEAR = 365*DAY
        cases = [(None, "unknown"),
                 (0, "0 seconds"),
                 (1, "1 second"),
                 (2, "2 seconds"),
                 (119, "119 seconds"),
                 (2*MIN, "2 minutes"),
                 (60*MIN, "60 minutes"),
                 (179*MIN, "179 minutes"),
                 (180*MIN, "3 hours"),
                 (4*HOUR, "4 hours"),
                 (2*DAY, "2 days"),
                 (2*MONTH, "2 months"),
                 (5*YEAR, "5 years"),
                 ]
        for seconds, expected in cases:
            self.failUnlessEqual(a(seconds), expected)

    def test_space(self):
        K, Ki = 1000, 1024
        tests_si = [(None, "unknown"),
                    (0, "0 B"),
                    (1, "1 B"),
                    (999, "999 B"),
                    (1000, "1000 B"),
                    (1023, "1023 B"),
                    (1024, "1.02 kB"),
                    (20*K, "20.00 kB"),
                    (Ki**2, "1.05 MB"),
                    (K**2, "1.00 MB"),
                    (K**3, "1.00 GB"),
                    (K**4, "1.00 TB"),
                    (K**5, "1.00 PB"),
                    (1234567890123456, "1.23 PB"),
                    ]
        tests_base1024 = [(None, "unknown"),
                          (0, "0 B"),
                          (1, "1 B"),
                          (999, "999 B"),
                          (1000, "1000 B"),
                          (1023, "1023 B"),
                          (1024, "1.00 kiB"),
                          (20*Ki, "20.00 kiB"),
                          (K**2, "976.56 kiB"),
                          (Ki**2, "1.00 MiB"),
                          (Ki**3, "1.00 GiB"),
                          (Ki**4, "1.00 TiB"),
                          (K**5, "909.49 TiB"),
                          (Ki**5, "1.00 PiB"),
                          (1234567890123456, "1.10 PiB"),
                          ]
        # SI=True uses powers of 1000 (kB, MB, ...); SI=False uses powers
        # of 1024 (kiB, MiB, ...).
        for use_si, table in [(True, tests_si), (False, tests_base1024)]:
            for (x, expected) in table:
                got = abbreviate.abbreviate_space(x, SI=use_si)
                self.failUnlessEqual(got, expected)

        self.failUnlessEqual(abbreviate.abbreviate_space_both(1234567),
                             "(1.23 MB, 1.18 MiB)")

    def test_parse_space(self):
        p = abbreviate.parse_abbreviated_size
        cases = [("", None),
                 (None, None),
                 ("123", 123),
                 ("123B", 123),
                 ("2K", 2000),
                 ("2kb", 2000),
                 ("2KiB", 2048),
                 ("10MB", 10*1000*1000),
                 ("10MiB", 10*1024*1024),
                 ("5G", 5*1000*1000*1000),
                 ("4GiB", 4*1024*1024*1024),
                 ]
        for text, expected in cases:
            self.failUnlessEqual(p(text), expected)
        # unknown units are rejected, and the error quotes the input
        e = self.failUnlessRaises(ValueError, p, "12 cubits")
        self.failUnless("12 cubits" in str(e))
753
class Limiter(unittest.TestCase):
    """Tests for limiter.ConcurrencyLimiter using slow fake jobs."""
    timeout = 480 # This takes longer than 240 seconds on Francois's arm box.

    def job(self, i, foo):
        """Record the call, bump the concurrency counters, and finish after
        one second of reactor time with the result "done <i>"."""
        self.calls.append( (i, foo) )
        self.simultaneous += 1
        if self.simultaneous > self.peak_simultaneous:
            self.peak_simultaneous = self.simultaneous
        result = defer.Deferred()
        def _finish():
            self.simultaneous -= 1
            result.callback("done %d" % i)
        reactor.callLater(1.0, _finish)
        return result

    def bad_job(self, i, foo):
        """Fail synchronously without touching any bookkeeping."""
        raise ValueError("bad_job %d" % i)

    def test_limiter(self):
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        d = defer.DeferredList(pending, fireOnOneErrback=True)
        def _check(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self.failUnless(self.peak_simultaneous <= 10)
            self.failUnlessEqual(len(self.calls), 20)
            for i in range(20):
                self.failUnlessIn((i, str(i)), self.calls)
        d.addCallback(_check)
        return d

    def test_errors(self):
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0
        cl = limiter.ConcurrencyLimiter()
        pending = [cl.add(self.job, i, foo=str(i)) for i in range(20)]
        # one job that raises; it must not disturb the 20 good ones
        d_bad = cl.add(self.bad_job, 21, "21")
        d = defer.DeferredList(pending, fireOnOneErrback=True)
        def _check_good_jobs(res):
            results = []
            for (success, result) in res:
                self.failUnlessEqual(success, True)
                results.append(result)
            self.failUnlessEqual(sorted(results),
                                 sorted(["done %d" % i for i in range(20)]))
            self.failUnless(self.peak_simultaneous <= 10)
            self.failUnlessEqual(len(self.calls), 20)
            for i in range(20):
                self.failUnlessIn((i, str(i)), self.calls)
            def _unexpected_success(res):
                self.fail("should have failed, not got %s" % (res,))
            def _expected_failure(f):
                f.trap(ValueError)
                self.failUnlessIn("bad_job 21", str(f))
            d_bad.addCallbacks(_unexpected_success, _expected_failure)
            return d_bad
        d.addCallback(_check_good_jobs)
        def _check_final_state(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self.failUnless(self.peak_simultaneous <= 10)
            self.failUnlessEqual(len(self.calls), 20)
            for i in range(20):
                self.failUnlessIn((i, str(i)), self.calls)
        d.addCallback(_check_final_state)
        return d
828
class TimeFormat(unittest.TestCase):
    """Tests for the allmydata.util.time_format conversion helpers."""

    def test_epoch(self):
        return self._help_test_epoch()

    def test_epoch_in_London(self):
        # Europe/London is a particularly troublesome timezone.  Nowadays, its
        # offset from GMT is 0.  But in 1970, its offset from GMT was 1.
        # (Apparently in 1970 Britain had redefined standard time to be GMT+1
        # and stayed in standard time all year round, whereas today
        # Europe/London standard time is GMT and Europe/London Daylight
        # Savings Time is GMT+1.)  When this test was written,
        # time_format.iso_utc_time_to_localseconds() broke in this timezone,
        # so the epoch checks are rerun here with TZ forced to Europe/London.
        origtz = os.environ.get('TZ')
        try:
            # Change TZ inside the try so the finally clause always restores
            # the original timezone, even if the switch itself fails.
            os.environ['TZ'] = "Europe/London"
            if hasattr(time, 'tzset'):
                time.tzset()
            return self._help_test_epoch()
        finally:
            if origtz is None:
                # pop() rather than del: tolerates TZ never having been set
                # (e.g. if the assignment above is what failed).
                os.environ.pop('TZ', None)
            else:
                os.environ['TZ'] = origtz
            if hasattr(time, 'tzset'):
                time.tzset()

    def _help_test_epoch(self):
        """Check the iso_utc / iso_utc_time_to_seconds round trips under the
        currently configured timezone, and that time.tzname is unchanged."""
        origtzname = time.tzname
        # all three accepted date/time separators parse identically
        s = time_format.iso_utc_time_to_seconds("1970-01-01T00:00:01")
        self.failUnlessEqual(s, 1.0)
        s = time_format.iso_utc_time_to_seconds("1970-01-01_00:00:01")
        self.failUnlessEqual(s, 1.0)
        s = time_format.iso_utc_time_to_seconds("1970-01-01 00:00:01")
        self.failUnlessEqual(s, 1.0)

        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")

        now = time.time()
        isostr = time_format.iso_utc(now)
        timestamp = time_format.iso_utc_time_to_seconds(isostr)
        self.failUnlessEqual(int(timestamp), int(now))

        def my_time():
            return 1.0
        self.failUnlessEqual(time_format.iso_utc(t=my_time),
                             "1970-01-01_00:00:01")
        e = self.failUnlessRaises(ValueError,
                                  time_format.iso_utc_time_to_seconds,
                                  "invalid timestring")
        self.failUnlessIn("not a complete ISO8601 timestamp", str(e))
        s = time_format.iso_utc_time_to_seconds("1970-01-01_00:00:01.500")
        self.failUnlessEqual(s, 1.5)

        # Look for daylight-savings-related errors.
        thatmomentinmarch = time_format.iso_utc_time_to_seconds("2009-03-20 21:49:02.226536")
        self.failUnlessEqual(thatmomentinmarch, 1237585742.226536)
        self.failUnlessEqual(origtzname, time.tzname)

    def test_iso_utc(self):
        when = 1266760143.7841301
        out = time_format.iso_utc_date(when)
        self.failUnlessEqual(out, "2010-02-21")
        out = time_format.iso_utc_date(t=lambda: when)
        self.failUnlessEqual(out, "2010-02-21")
        out = time_format.iso_utc(when)
        self.failUnlessEqual(out, "2010-02-21_13:49:03.784130")
        out = time_format.iso_utc(when, sep="-")
        self.failUnlessEqual(out, "2010-02-21-13:49:03.784130")

    def test_parse_duration(self):
        p = time_format.parse_duration
        DAY = 24*60*60
        self.failUnlessEqual(p("1 day"), DAY)
        self.failUnlessEqual(p("2 days"), 2*DAY)
        # parse_duration counts a month as 31 days
        self.failUnlessEqual(p("3 months"), 3*31*DAY)
        self.failUnlessEqual(p("4 mo"), 4*31*DAY)
        self.failUnlessEqual(p("5 years"), 5*365*DAY)
        e = self.failUnlessRaises(ValueError, p, "123")
        self.failUnlessIn("no unit (like day, month, or year) in '123'",
                          str(e))

    def test_parse_date(self):
        self.failUnlessEqual(time_format.parse_date("2010-02-21"), 1266710400)
916
class CacheDir(unittest.TestCase):
    def test_basic(self):
        """check() should delete a cache file only after every object
        returned by get_file() for it has been released and the file
        counts as old by the cdm.old threshold."""
        basedir = "test_util/CacheDir/test_basic"

        def _failIfExists(name):
            absfn = os.path.join(basedir, name)
            self.failIf(os.path.exists(absfn),
                        "%s exists but it shouldn't" % absfn)

        def _failUnlessExists(name):
            absfn = os.path.join(basedir, name)
            self.failUnless(os.path.exists(absfn),
                            "%s doesn't exist but it should" % absfn)

        def _write_hi(cachefile):
            # Close the handle even if write() raises.  Only the filename is
            # read, so no reference to the cache object outlives this call.
            f = open(cachefile.get_filename(), "wb")
            try:
                f.write("hi")
            finally:
                f.close()

        cdm = cachedir.CacheDirectoryManager(basedir)
        a = cdm.get_file("a")
        b = cdm.get_file("b")
        c = cdm.get_file("c")
        _write_hi(a)
        _write_hi(b)
        _write_hi(c)

        _failUnlessExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")

        cdm.check()

        _failUnlessExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")

        del a
        # this file won't be deleted yet, because it isn't old enough
        cdm.check()
        _failUnlessExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")

        # we change the definition of "old" to make everything old
        cdm.old = -10

        cdm.check()
        _failIfExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")

        cdm.old = 60*60

        del b

        cdm.check()
        _failIfExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")

        b2 = cdm.get_file("b")

        cdm.check()
        _failIfExists("a")
        _failUnlessExists("b")
        _failUnlessExists("c")
        del b2
980
ctr = [0]
class EqButNotIs:
    """Wrap a value so that it compares equal to it without being it.

    Every rich comparison is delegated to the wrapped value ``x``.  Each
    instance takes its own hash from the module-level ``ctr`` counter, so
    wrappers are equal to their value yet remain distinct objects with
    per-instance hashes.  The DictUtil tests use this to check which exact
    key/value objects a mapping keeps.
    """
    def __init__(self, x):
        self.x = x
        # claim the next counter value as this instance's private hash
        self.hash = ctr[0]
        ctr[0] = self.hash + 1
    def __hash__(self):
        return self.hash
    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.x,)

    # All comparisons defer to the wrapped value.
    def __eq__(self, other):
        return self.x == other
    def __ne__(self, other):
        return self.x != other
    def __lt__(self, other):
        return self.x < other
    def __le__(self, other):
        return self.x <= other
    def __gt__(self, other):
        return self.x > other
    def __ge__(self, other):
        return self.x >= other
1003
class DictUtil(unittest.TestCase):
    """Tests for the helpers and mapping classes in allmydata.util.dictutil."""

    def _help_test_empty_dict(self, klass):
        """klass() and klass({}) must compare equal and both be empty."""
        d1 = klass()
        d2 = klass({})

        self.failUnless(d1 == d2, "d1: %r, d2: %r" % (d1, d2,))
        self.failUnless(len(d1) == 0)
        self.failUnless(len(d2) == 0)

    def _help_test_nonempty_dict(self, klass):
        """Two klass instances built from equal data must compare equal."""
        d1 = klass({'a': 1, 'b': "eggs", 3: "spam",})
        d2 = klass({'a': 1, 'b': "eggs", 3: "spam",})

        self.failUnless(d1 == d2)
        self.failUnless(len(d1) == 3, "%s, %s" % (len(d1), d1,))
        self.failUnless(len(d2) == 3)

    def _help_test_eq_but_notis(self, klass):
        """Drive klass with keys/values that are equal but not identical.

        EqButNotIs(n) == n but is a distinct object with its own hash.  The
        identity checks below (``x is 8`` etc.) are deliberate: they tell
        the real int apart from its EqButNotIs look-alike (relying on
        CPython's small-int caching).
        """
        d = klass({'a': 3, 'b': EqButNotIs(3), 'c': 3})
        d.pop('b')

        d.clear()
        d['a'] = 3
        d['b'] = EqButNotIs(3)
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['b'] = EqButNotIs(3)
        d['a'] = 3
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['a'] = EqButNotIs(3)
        d['c'] = 3
        d['a'] = 3

        d.clear()
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

        d.clear()
        d[3] = 7
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 8
        self.failUnless(filter(lambda x: x is 8,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake7,  d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not filter(lambda x: x is 7,  d.itervalues()))
        self.failUnless(filter(lambda x: x is fake3,  d.iterkeys()))
        self.failUnless(filter(lambda x: x is 3,  d.iterkeys()))
        d[fake3] = 8

    def test_all(self):
        # Run every generic-mapping helper against every dictutil mapping
        # class.  Each helper runs once per class; in particular the
        # empty-dict checks are actually exercised.
        klasses = (dictutil.UtilDict, dictutil.NumDict,
                   dictutil.ValueOrderedDict)
        for helper in (self._help_test_empty_dict,
                       self._help_test_nonempty_dict,
                       self._help_test_eq_but_notis):
            for klass in klasses:
                helper(klass)

    def test_dict_of_sets(self):
        ds = dictutil.DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        # discarding the last member removes the key entirely
        self.failIf(2 in ds)

        ds.union(1, ["a", "e"])
        ds.union(3, ["f"])
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f"]))
        ds2 = dictutil.DictOfSets()
        ds2.add(3, "f")
        ds2.add(3, "g")
        ds2.add(4, "h")
        ds.update(ds2)
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f", "g"]))
        self.failUnlessEqual(ds[4], set(["h"]))

    def test_move(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(1, d1, d2)
        self.failUnlessEqual(d1, {2: "b"})
        self.failUnlessEqual(d2, {1: "a", 2: "c", 3: "d"})

        # moving onto an existing key overwrites the destination's value
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(2, d1, d2)
        self.failUnlessEqual(d1, {1: "a"})
        self.failUnlessEqual(d2, {2: "b", 3: "d"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        self.failUnlessRaises(KeyError, dictutil.move, 5, d1, d2, strict=True)

    def test_subtract(self):
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

    def test_utildict(self):
        d = dictutil.UtilDict({1: "a", 2: "b"})
        d.del_if_present(1)
        d.del_if_present(3)
        self.failUnlessEqual(d, {2: "b"})
        def eq(a, b):
            return a == b
        # comparing a UtilDict against a non-mapping raises
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d = dictutil.UtilDict({1: "b", 2: "a"})
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [(2, "a"), (1, "b")])
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(repr(d), "{1: 'b', 2: 'a'}")
        self.failUnless(1 in d)

        d2 = dictutil.UtilDict({3: "c", 4: "d"})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d[1], "b")
        self.failUnlessEqual(sorted(d), [1,2])

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.UtilDict))

        d4 = d.fromkeys([3,4], "e")
        self.failUnlessEqual(d4, {3: "e", 4: "e"})

        self.failUnlessEqual(d.get(1), "b")
        self.failUnlessEqual(d.get(3), None)
        self.failUnlessEqual(d.get(3, "default"), "default")
        self.failUnlessEqual(sorted(d.items()),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.iteritems()),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.keys()), [1, 2])
        self.failUnlessEqual(sorted(d.values()), ["a", "b"])
        x = d.setdefault(1, "new")
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d[1], "b")
        x = d.setdefault(3, "new")
        self.failUnlessEqual(x, "new")
        self.failUnlessEqual(d[3], "new")
        del d[3]

        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        self.failUnlessRaises(KeyError, d.popitem)

    def test_numdict(self):
        d = dictutil.NumDict({"a": 1, "b": 2})

        # add_num(key, amount, default): missing keys start from default
        d.add_num("a", 10, 5)
        d.add_num("c", 20, 5)
        d.add_num("d", 30)
        self.failUnlessEqual(d, {"a": 11, "b": 2, "c": 25, "d": 30})

        d.subtract_num("a", 10)
        d.subtract_num("e", 10)
        d.subtract_num("f", 10, 15)
        self.failUnlessEqual(d, {"a": 1, "b": 2, "c": 25, "d": 30,
                                 "e": -10, "f": 5})

        self.failUnlessEqual(d.sum(), sum([1, 2, 25, 30, -10, 5]))

        d = dictutil.NumDict()
        d.inc("a")
        d.inc("a")
        d.inc("b", 5)
        self.failUnlessEqual(d, {"a": 2, "b": 6})
        d.dec("a")
        d.dec("c")
        d.dec("d", 5)
        self.failUnlessEqual(d, {"a": 1, "b": 6, "c": -1, "d": 4})
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [("a", 1), ("b", 6), ("c", -1), ("d", 4)])
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [("c", -1), ("a", 1), ("d", 4), ("b", 6)])
        self.failUnlessEqual(d.item_with_largest_value(), ("b", 6))

        d = dictutil.NumDict({"a": 1, "b": 2})
        self.failUnlessEqual(repr(d), "{'a': 1, 'b': 2}")
        self.failUnless("a" in d)

        d2 = dictutil.NumDict({"c": 3, "d": 4})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d["a"], 1)
        self.failUnlessEqual(sorted(d), ["a","b"])
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.NumDict))

        d4 = d.fromkeys(["a","b"], 5)
        self.failUnlessEqual(d4, {"a": 5, "b": 5})

        self.failUnlessEqual(d.get("a"), 1)
        # NumDict.get() defaults to 0 rather than None
        self.failUnlessEqual(d.get("c"), 0)
        self.failUnlessEqual(d.get("c", 5), 5)
        self.failUnlessEqual(sorted(d.items()),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.iteritems()),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.keys()), ["a", "b"])
        self.failUnlessEqual(sorted(d.values()), [1, 2])
        self.failUnless(d.has_key("a"))
        self.failIf(d.has_key("c"))

        x = d.setdefault("c", 3)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        x = d.setdefault("c", 5)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        del d["c"]

        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        self.failUnlessRaises(KeyError, d.popitem)

        d.update({"c": 3})
        d.update({"c": 4, "d": 5})
        self.failUnlessEqual(d, {"c": 4, "d": 5})

    def test_del_if_present(self):
        d = {1: "a", 2: "b"}
        dictutil.del_if_present(d, 1)
        dictutil.del_if_present(d, 3)
        self.failUnlessEqual(d, {2: "b"})

    def test_valueordereddict(self):
        d = dictutil.ValueOrderedDict()
        d["a"] = 3
        d["b"] = 2
        d["c"] = 1

        # iteration order follows the values, not insertion order
        self.failUnlessEqual(d, {"a": 3, "b": 2, "c": 1})
        self.failUnlessEqual(d.items(), [("c", 1), ("b", 2), ("a", 3)])
        self.failUnlessEqual(d.values(), [1, 2, 3])
        self.failUnlessEqual(d.keys(), ["c", "b", "a"])
        self.failUnlessEqual(repr(d), "<ValueOrderedDict {c: 1, b: 2, a: 3}>")
        self.failIf(d == {"a": 4})
        self.failUnless(d != {"a": 4})

        x = d.setdefault("d", 0)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)
        x = d.setdefault("d", -1)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)

        # remove(key, default, strictkey)
        x = d.remove("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.remove, "e", "default", True)
        x = d.remove("d", 5)
        self.failUnlessEqual(x, 0)

        x = d.__getitem__("c")
        self.failUnlessEqual(x, 1)
        x = d.__getitem__("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.__getitem__, "e", "default", True)

        self.failUnlessEqual(d.popitem(), ("c", 1))
        self.failUnlessEqual(d.popitem(), ("b", 2))
        self.failUnlessEqual(d.popitem(), ("a", 3))
        self.failUnlessRaises(KeyError, d.popitem)

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop("d", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.pop, "d", "default", True)
        x = d.pop("b")
        self.failUnlessEqual(x, 2)
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop_from_list(1) # pop the second item, b/2
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

    def test_auxdict(self):
        d = dictutil.AuxValueDict()
        # we put the serialized form in the auxdata
        d.set_with_aux("key", ("filecap", "metadata"), "serialized")

        self.failUnlessEqual(d.keys(), ["key"])
        self.failUnlessEqual(d["key"], ("filecap", "metadata"))
        self.failUnlessEqual(d.get_aux("key"), "serialized")
        def _get_missing(key):
            return d[key]
        self.failUnlessRaises(KeyError, _get_missing, "nonkey")
        self.failUnlessEqual(d.get("nonkey"), None)
        self.failUnlessEqual(d.get("nonkey", "nonvalue"), "nonvalue")
        self.failUnlessEqual(d.get_aux("nonkey"), None)
        self.failUnlessEqual(d.get_aux("nonkey", "nonvalue"), "nonvalue")

        # plain assignment clears any previously stored aux data
        d["key"] = ("filecap2", "metadata2")
        self.failUnlessEqual(d["key"], ("filecap2", "metadata2"))
        self.failUnlessEqual(d.get_aux("key"), None)

        d.set_with_aux("key2", "value2", "aux2")
        self.failUnlessEqual(sorted(d.keys()), ["key", "key2"])
        del d["key2"]
        self.failUnlessEqual(d.keys(), ["key"])
        self.failIf("key2" in d)
        self.failUnlessRaises(KeyError, _get_missing, "key2")
        self.failUnlessEqual(d.get("key2"), None)
        self.failUnlessEqual(d.get_aux("key2"), None)
        d["key2"] = "newvalue2"
        self.failUnlessEqual(d.get("key2"), "newvalue2")
        self.failUnlessEqual(d.get_aux("key2"), None)

        # the constructor accepts the same argument forms as dict()
        d = dictutil.AuxValueDict({1:2,3:4})
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict([ (1,2), (3,4) ])
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict(one=1, two=2)
        self.failUnlessEqual(sorted(d.keys()), ["one","two"])
        self.failUnlessEqual(d["one"], 1)
        self.failUnlessEqual(d.get_aux("one"), None)
1381
1382 class Pipeline(unittest.TestCase):
1383     def pause(self, *args, **kwargs):
1384         d = defer.Deferred()
1385         self.calls.append( (d, args, kwargs) )
1386         return d
1387
1388     def failUnlessCallsAre(self, expected):
1389         #print self.calls
1390         #print expected
1391         self.failUnlessEqual(len(self.calls), len(expected), self.calls)
1392         for i,c in enumerate(self.calls):
1393             self.failUnlessEqual(c[1:], expected[i], str(i))
1394
    def test_basic(self):
        """Walk a Pipeline through fill, overflow, and drain, checking when
        add() and flush() Deferreds fire and how p.gauge tracks sizes."""
        self.calls = []
        finished = []
        # the sizes passed to add() accumulate in p.gauge; 100 is the value
        # the gauge is compared against (in this test, reaching 110 holds
        # back the add() Deferred)
        p = pipeline.Pipeline(100)

        d = p.flush() # fires immediately
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        finished = []

        d = p.add(10, self.pause, "one")
        # the call should start right away, and our return Deferred should
        # fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ) ])
        self.failUnlessEqual(p.gauge, 10)

        # pipeline: [one]

        finished = []
        d = p.add(20, self.pause, "two", kw=2)
        # pipeline: [one, two]

        # the call and the Deferred should fire right away
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ])
        self.failUnlessEqual(p.gauge, 30)

        # completing call #1 releases its 10 units from the gauge
        self.calls[0][0].callback("one-result")
        # pipeline: [two]
        self.failUnlessEqual(p.gauge, 20)

        finished = []
        d = p.add(90, self.pause, "three", "posarg1")
        # pipeline: [two, three]
        flushed = []
        fd = p.flush()
        fd.addCallbacks(flushed.append, log.err)
        # flush() waits for both outstanding calls
        self.failUnlessEqual(flushed, [])

        # the call will be made right away, but the return Deferred will not,
        # because the pipeline is now full.
        d.addCallbacks(finished.append, log.err)
        self.failUnlessEqual(len(finished), 0)
        self.failUnlessCallsAre([ ( ("one",) , {} ),
                                  ( ("two",) , {"kw": 2} ),
                                  ( ("three", "posarg1"), {} ),
                                  ])
        self.failUnlessEqual(p.gauge, 110)

        # adding while the pipeline is full is refused outright
        self.failUnlessRaises(pipeline.SingleFileError, p.add, 10, self.pause)

        # retiring either call will unblock the pipeline, causing the #3
        # Deferred to fire
        self.calls[2][0].callback("three-result")
        # pipeline: [two]

        self.failUnlessEqual(len(finished), 1)
        self.failUnlessEqual(finished[0], None)
        self.failUnlessEqual(flushed, [])

        # retiring call#2 will finally allow the flush() Deferred to fire
        self.calls[1][0].callback("two-result")
        self.failUnlessEqual(len(flushed), 1)
1465
1466     def test_errors(self):
1467         self.calls = []
1468         p = pipeline.Pipeline(100)
1469
1470         d1 = p.add(200, self.pause, "one")
1471         d2 = p.flush()
1472
1473         finished = []
1474         d1.addBoth(finished.append)
1475         self.failUnlessEqual(finished, [])
1476
1477         flushed = []
1478         d2.addBoth(flushed.append)
1479         self.failUnlessEqual(flushed, [])
1480
1481         self.calls[0][0].errback(ValueError("oops"))
1482
1483         self.failUnlessEqual(len(finished), 1)
1484         f = finished[0]
1485         self.failUnless(isinstance(f, Failure))
1486         self.failUnless(f.check(pipeline.PipelineError))
1487         self.failUnlessIn("PipelineError", str(f.value))
1488         self.failUnlessIn("ValueError", str(f.value))
1489         r = repr(f.value)
1490         self.failUnless("ValueError" in r, r)
1491         f2 = f.value.error
1492         self.failUnless(f2.check(ValueError))
1493
1494         self.failUnlessEqual(len(flushed), 1)
1495         f = flushed[0]
1496         self.failUnless(isinstance(f, Failure))
1497         self.failUnless(f.check(pipeline.PipelineError))
1498         f2 = f.value.error
1499         self.failUnless(f2.check(ValueError))
1500
1501         # now that the pipeline is in the failed state, any new calls will
1502         # fail immediately
1503
1504         d3 = p.add(20, self.pause, "two")
1505
1506         finished = []
1507         d3.addBoth(finished.append)
1508         self.failUnlessEqual(len(finished), 1)
1509         f = finished[0]
1510         self.failUnless(isinstance(f, Failure))
1511         self.failUnless(f.check(pipeline.PipelineError))
1512         r = repr(f.value)
1513         self.failUnless("ValueError" in r, r)
1514         f2 = f.value.error
1515         self.failUnless(f2.check(ValueError))
1516
1517         d4 = p.flush()
1518         flushed = []
1519         d4.addBoth(flushed.append)
1520         self.failUnlessEqual(len(flushed), 1)
1521         f = flushed[0]
1522         self.failUnless(isinstance(f, Failure))
1523         self.failUnless(f.check(pipeline.PipelineError))
1524         f2 = f.value.error
1525         self.failUnless(f2.check(ValueError))
1526
    def test_errors2(self):
        # Queue three calls and a flush(). The first call fails, which must
        # errback the flush() Deferred with a PipelineError wrapping the
        # ValueError. The other two calls then finish (one success, one
        # failure) after the pipeline has already failed; there are no
        # assertions after that, so the test passes as long as retiring
        # them raises nothing.
        # NOTE(review): self.calls is presumably filled by self.pause (not
        # visible here); entries are indexed as [n][0] -> Deferred.
        self.calls = []
        p = pipeline.Pipeline(100)

        d1 = p.add(10, self.pause, "one")
        d2 = p.add(20, self.pause, "two")
        d3 = p.add(30, self.pause, "three")
        d4 = p.flush()

        # one call fails, then the second one succeeds: make sure
        # ExpandableDeferredList tolerates the second one

        flushed = []
        d4.addBoth(flushed.append)
        self.failUnlessEqual(flushed, [])

        self.calls[0][0].errback(ValueError("oops"))
        self.failUnlessEqual(len(flushed), 1)
        f = flushed[0]
        self.failUnless(isinstance(f, Failure))
        self.failUnless(f.check(pipeline.PipelineError))
        f2 = f.value.error
        self.failUnless(f2.check(ValueError))

        # retire the remaining calls after the pipeline has already failed
        self.calls[1][0].callback("two-result")
        self.calls[2][0].errback(ValueError("three-error"))

        # d1..d3 are otherwise unused; drop every Deferred reference
        # explicitly at the end of the test
        del d1,d2,d3,d4
1555
class SampleError(Exception):
    """Exception raised on purpose by Log.test_err; never a real failure."""
1558
class Log(unittest.TestCase):
    def test_err(self):
        # tahoe_log.err() must accept a Failure along with format/level/umid
        # keywords. The twisted.log.err it records is flushed afterwards so
        # trial does not mark this test as an [ERROR].
        if not hasattr(self, "flushLoggedErrors"):
            # without flushLoggedErrors, we can't get rid of the
            # twisted.log.err that tahoe_log records, so we can't keep this
            # test from [ERROR]ing
            raise unittest.SkipTest("needs flushLoggedErrors from Twisted-2.5.0")
        try:
            raise SampleError("simple sample")
        except SampleError:
            # catch only the exception raised above (was a bare except:);
            # Failure() with no arguments wraps the exception being handled
            f = Failure()
        tahoe_log.err(format="intentional sample error",
                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
        self.flushLoggedErrors(SampleError)
1573
1574
class SimpleSpans:
    # this is a simple+inefficient form of util.spans.Spans . We compare the
    # behavior of this reference model against the real (efficient) form.
    # The entire state is self._have: the set of member byte offsets.

    def __init__(self, _span_or_start=None, length=None):
        # SimpleSpans()              -> empty
        # SimpleSpans(start, length) -> the single span [start, start+length)
        # SimpleSpans(iterable)      -> union of its (start, length) pairs,
        #                               e.g. a list of tuples or another spans
        #                               object (iterated via its __iter__)
        self._have = set()
        if length is not None:
            self.add(_span_or_start, length)
        elif _span_or_start:
            # a distinct loop name so the 'length' parameter is not rebound
            for (start, span_length) in _span_or_start:
                self.add(start, span_length)

    def add(self, start, length):
        # mark offsets [start, start+length) present; returns self so calls
        # can be chained
        self._have.update(range(start, start+length))
        return self

    def remove(self, start, length):
        # mark offsets [start, start+length) absent (absent ones are
        # ignored); returns self so calls can be chained
        self._have.difference_update(range(start, start+length))
        return self

    def each(self):
        # sorted list of every member offset
        return sorted(self._have)

    def __iter__(self):
        # yield each maximal run of consecutive offsets as (start, length),
        # in ascending order
        prevstart = None
        prevend = None
        for i in sorted(self._have):
            if prevstart is None:
                prevstart = prevend = i
                continue
            if i == prevend+1:
                prevend = i
                continue
            yield (prevstart, prevend-prevstart+1)
            prevstart = prevend = i
        if prevstart is not None:
            yield (prevstart, prevend-prevstart+1)

    def __len__(self):
        # number of member bytes; this also gets us bool(s)
        return len(self._have)

    def __add__(self, other):
        # union, as a new object; other is any iterable of (start, length)
        s = self.__class__(self)
        for (start, length) in other:
            s.add(start, length)
        return s

    def __sub__(self, other):
        # difference, as a new object; other is any iterable of
        # (start, length)
        s = self.__class__(self)
        for (start, length) in other:
            s.remove(start, length)
        return s

    def __iadd__(self, other):
        for (start, length) in other:
            self.add(start, length)
        return self

    def __isub__(self, other):
        for (start, length) in other:
            self.remove(start, length)
        return self

    def __and__(self, other):
        # intersection, as a new object; other only needs an .each() method
        # returning member offsets
        s = self.__class__()
        s._have = self._have.intersection(other.each())
        return s

    def __contains__(self, span):
        # True if every offset of the (start, length) span is a member; a
        # zero or negative length is vacuously contained. The tuple is
        # unpacked in the body because tuple parameters are Python-2-only
        # (PEP 3113).
        start, length = span
        return all(i in self._have for i in range(start, start+length))
1655
class ByteSpans(unittest.TestCase):
    # Exercises util.spans.Spans and overlap() against hand-computed
    # expectations and against the SimpleSpans reference model. The Spans
    # API used here:
    #
    # s()
    # s(start,length)
    # s(s0)
    # s.add(start,length) : returns s
    # s.remove(start,length) : returns s
    # s.each() -> list of byte offsets, mostly for testing
    # list(s) -> list of (start,length) tuples, one per span
    # (start,length) in s -> True if (start..start+length-1) are all members
    #  NOT equivalent to x in list(s)
    # len(s) -> number of bytes, for testing, bool(), and accounting/limiting
    # bool(s)  (__len__)
    # s = s1+s2, s1-s2, s1&s2, +=s1, -=s1

    def test_basic(self):
        s = Spans()
        self.failUnlessEqual(list(s), [])
        self.failIf(s)
        self.failIf((0,1) in s)
        self.failUnlessEqual(len(s), 0)

        s1 = Spans(3, 4) # 3,4,5,6
        self._check1(s1)

        s2 = Spans(s1)
        self._check1(s2)

        s2.add(10,2) # 10,11
        self._check1(s1)
        self.failUnless((10,1) in s2)
        self.failIf((10,1) in s1)
        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
        self.failUnlessEqual(len(s2), 6)

        s2.add(15,2).add(20,2)
        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
        self.failUnlessEqual(len(s2), 10)

        s2.remove(4,3).remove(15,1)
        self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
        self.failUnlessEqual(len(s2), 6)

        # partial-overlap intersection, first on the reference model ...
        s1 = SimpleSpans(3, 4) # 3 4 5 6
        s2 = SimpleSpans(5, 4) # 5 6 7 8
        i = s1 & s2
        self.failUnlessEqual(list(i.each()), [5, 6])
        # ... and then on the real Spans class
        i = Spans(3, 4) & Spans(5, 4)
        self.failUnlessEqual(list(i.each()), [5, 6])

    def _check1(self, s):
        # assert that s holds exactly the offsets 3,4,5,6
        self.failUnlessEqual(list(s), [(3,4)])
        self.failUnless(s)
        self.failUnlessEqual(len(s), 4)
        self.failIf((0,1) in s)
        self.failUnless((3,4) in s)
        self.failUnless((3,1) in s)
        self.failUnless((5,2) in s)
        self.failUnless((6,1) in s)
        self.failIf((6,2) in s)
        self.failIf((7,1) in s)
        self.failUnlessEqual(list(s.each()), [3,4,5,6])

    def test_math(self):
        s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
        s2 = Spans(5, 3) # 5,6,7
        s3 = Spans(8, 4) # 8,9,10,11

        s = s1 - s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
        s = s1 - s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
        s = s2 - s3
        self.failUnlessEqual(list(s.each()), [5,6,7])
        s = s1 & s2
        self.failUnlessEqual(list(s.each()), [5,6,7])
        s = s2 & s1
        self.failUnlessEqual(list(s.each()), [5,6,7])
        s = s1 & s3
        self.failUnlessEqual(list(s.each()), [8,9])
        s = s3 & s1
        self.failUnlessEqual(list(s.each()), [8,9])
        s = s2 & s3
        self.failUnlessEqual(list(s.each()), [])
        s = s3 & s2
        self.failUnlessEqual(list(s.each()), [])
        s = Spans() & s3
        self.failUnlessEqual(list(s.each()), [])
        s = s3 & Spans()
        self.failUnlessEqual(list(s.each()), [])

        s = s1 + s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
        s = s1 + s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
        s = s2 + s3
        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])

        s = Spans(s1)
        s -= s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
        s = Spans(s1)
        s -= s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
        s = Spans(s2)
        s -= s3
        self.failUnlessEqual(list(s.each()), [5,6,7])

        s = Spans(s1)
        s += s2
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
        s = Spans(s1)
        s += s3
        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
        s = Spans(s2)
        s += s3
        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])

    def test_random(self):
        # attempt to increase coverage of corner cases by comparing behavior
        # of a simple-but-slow model implementation against the
        # complex-but-fast actual implementation, in a large number of random
        # operations
        S1 = SimpleSpans
        S2 = Spans
        s1 = S1(); s2 = S2()
        seed = ""
        def _create(subseed):
            # build the same pseudo-random set of spans in both classes
            ns1 = S1(); ns2 = S2()
            for i in range(10):
                what = md5(subseed+str(i)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1,int(what[5:6], 16))
                ns1.add(start, length); ns2.add(start, length)
            return ns1, ns2

        #print
        for i in range(1000):
            # the hex digits of an md5 hash pick the operation (op), its
            # variant (subop) and its (start, length) arguments
            what = md5(seed+str(i)).hexdigest()
            op = what[0]
            subop = what[1]
            start = int(what[2:4], 16)
            length = max(1,int(what[5:6], 16))
            #print what
            if op in "0":
                if subop in "01234":
                    s1 = S1(); s2 = S2()
                elif subop in "5678":
                    s1 = S1(start, length); s2 = S2(start, length)
                else:
                    s1 = S1(s1); s2 = S2(s2)
                #print "s2 = %s" % s2.dump()
            elif op in "123":
                #print "s2.add(%d,%d)" % (start, length)
                s1.add(start, length); s2.add(start, length)
            elif op in "456":
                #print "s2.remove(%d,%d)" % (start, length)
                s1.remove(start, length); s2.remove(start, length)
            elif op in "78":
                ns1, ns2 = _create(what[7:11])
                #print "s2 + %s" % ns2.dump()
                s1 = s1 + ns1; s2 = s2 + ns2
            elif op in "9a":
                ns1, ns2 = _create(what[7:11])
                #print "%s - %s" % (s2.dump(), ns2.dump())
                s1 = s1 - ns1; s2 = s2 - ns2
            elif op in "bc":
                ns1, ns2 = _create(what[7:11])
                #print "s2 += %s" % ns2.dump()
                s1 += ns1; s2 += ns2
            elif op in "de":
                ns1, ns2 = _create(what[7:11])
                #print "%s -= %s" % (s2.dump(), ns2.dump())
                s1 -= ns1; s2 -= ns2
            else:
                ns1, ns2 = _create(what[7:11])
                #print "%s &= %s" % (s2.dump(), ns2.dump())
                s1 = s1 & ns1; s2 = s2 & ns2
            #print "s2 now %s" % s2.dump()
            self.failUnlessEqual(list(s1.each()), list(s2.each()))
            self.failUnlessEqual(len(s1), len(s2))
            self.failUnlessEqual(bool(s1), bool(s2))
            self.failUnlessEqual(list(s1), list(s2))
            for j in range(10):
                what = md5(what[12:14]+str(j)).hexdigest()
                start = int(what[2:4], 16)
                length = max(1, int(what[5:6], 16))
                span = (start, length)
                self.failUnlessEqual(bool(span in s1), bool(span in s2))

    def test_overlap(self):
        for a in range(20):
            for b in range(10):
                for c in range(20):
                    for d in range(10):
                        self._test_overlap(a,b,c,d)

    def _test_overlap(self, a, b, c, d):
        # overlap(a,b,c,d) must return the (start,length) of the
        # intersection of [a,a+b) and [c,c+d), or None when it is empty
        s1 = set(range(a,a+b))
        s2 = set(range(c,c+d))
        #print "---"
        #self._show_overlap(s1, "1")
        #self._show_overlap(s2, "2")
        o = overlap(a,b,c,d)
        expected = s1.intersection(s2)
        if not expected:
            self.failUnlessEqual(o, None)
        else:
            start,length = o
            so = set(range(start,start+length))
            #self._show_overlap(so, "o")
            self.failUnlessEqual(so, expected)

    def _show_overlap(self, s, c):
        # debugging aid for _test_overlap: draw the set of offsets s on
        # stdout as one row, character c at each member offset and a blank
        # elsewhere (uses the module-level sys import)
        out = sys.stdout
        if s:
            # +1 so the largest member is drawn as well
            for i in range(max(s)+1):
                if i in s:
                    out.write(c)
                else:
                    out.write(" ")
        out.write("\n")
1877
def extend(s, start, length, fill):
    # Return s padded on the right with the single character 'fill' so it
    # is at least start+length long; s itself is returned when already long
    # enough (in that case 'fill' is not checked).
    shortfall = start + length - len(s)
    if shortfall <= 0:
        return s
    assert len(fill) == 1
    return s + fill * shortfall
1883
def replace(s, start, data):
    # Return s with the characters at [start, start+len(data)) overwritten
    # by data; s must already be long enough to hold them.
    end = start + len(data)
    assert len(s) >= end
    head, tail = s[:start], s[end:]
    return head + data + tail
1887
class SimpleDataSpans:
    # Simple-but-slow reference model of util.spans.DataSpans, used to
    # cross-check the real implementation. State is two strings indexed by
    # byte offset:
    #  self.missing: "1" where the byte is absent, "0" where it is present
    #  self.data: the byte values; only meaningful where missing is "0"

    def __init__(self, other=None):
        self.missing = "" # "1" where missing, "0" where found
        self.data = ""
        if other:
            # copy every present byte out of another data-spans object
            for (start, data) in other.get_chunks():
                self.add(start, data)

    def __len__(self):
        # number of bytes currently present (portable; the old
        # str.translate(None, "1") form only works on Python 2)
        return self.missing.count("0")
    def _dump(self):
        # sorted list of the offsets of every present byte
        return [i for (i,c) in enumerate(self.missing) if c == "0"]
    def _have(self, start, length):
        # True only if every byte of [start, start+length) is present. A
        # zero-length request, or one running past the tracked range, is
        # never satisfied. Searching for "1" replaces the old int(m) test,
        # which Python 3.11+ rejects for slices longer than 4300 digits.
        m = self.missing[start:start+length]
        if not m or len(m) < length:
            return False
        return "1" not in m
    def get_chunks(self):
        # yield (offset, one-character string) for each present byte
        for i in self._dump():
            yield (i, self.data[i])
    def get_spans(self):
        # the present offsets as a SimpleSpans
        return SimpleSpans([(start,len(data))
                            for (start,data) in self.get_chunks()])
    def get(self, start, length):
        # the bytes at [start, start+length) if all are present, else None
        if self._have(start, length):
            return self.data[start:start+length]
        return None
    def pop(self, start, length):
        # like get(), but also marks the returned range as missing
        data = self.get(start, length)
        if data is not None:
            self.remove(start, length)
        return data
    def remove(self, start, length):
        # mark [start, start+length) missing, growing self.missing as
        # needed; self.data is left alone since it is ignored where missing
        self.missing = replace(extend(self.missing, start, length, "1"),
                               start, "1"*length)
    def add(self, start, data):
        # store data at offset start (overwriting any bytes already there)
        # and mark those bytes present
        self.missing = replace(extend(self.missing, start, len(data), "1"),
                               start, "0"*len(data))
        self.data = replace(extend(self.data, start, len(data), " "),
                            start, data)
1928
1929
1930 class StringSpans(unittest.TestCase):
1931     def do_basic(self, klass):
1932         ds = klass()
1933         self.failUnlessEqual(len(ds), 0)
1934         self.failUnlessEqual(list(ds._dump()), [])
1935         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
1936         s = ds.get_spans()
1937         self.failUnlessEqual(ds.get(0, 4), None)
1938         self.failUnlessEqual(ds.pop(0, 4), None)
1939         ds.remove(0, 4)
1940
1941         ds.add(2, "four")
1942         self.failUnlessEqual(len(ds), 4)
1943         self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
1944         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1945         s = ds.get_spans()
1946         self.failUnless((2,2) in s)
1947         self.failUnlessEqual(ds.get(0, 4), None)
1948         self.failUnlessEqual(ds.pop(0, 4), None)
1949         self.failUnlessEqual(ds.get(4, 4), None)
1950
1951         ds2 = klass(ds)
1952         self.failUnlessEqual(len(ds2), 4)
1953         self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
1954         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
1955         self.failUnlessEqual(ds2.get(0, 4), None)
1956         self.failUnlessEqual(ds2.pop(0, 4), None)
1957         self.failUnlessEqual(ds2.pop(2, 3), "fou")
1958         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
1959         self.failUnlessEqual(ds2.get(2, 3), None)
1960         self.failUnlessEqual(ds2.get(5, 1), "r")
1961         self.failUnlessEqual(ds.get(2, 3), "fou")
1962         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1963
1964         ds.add(0, "23")
1965         self.failUnlessEqual(len(ds), 6)
1966         self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
1967         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
1968         self.failUnlessEqual(ds.get(0, 4), "23fo")
1969         self.failUnlessEqual(ds.pop(0, 4), "23fo")
1970         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
1971         self.failUnlessEqual(ds.get(0, 4), None)
1972         self.failUnlessEqual(ds.pop(0, 4), None)
1973
1974         ds = klass()
1975         ds.add(2, "four")
1976         ds.add(3, "ea")
1977         self.failUnlessEqual(ds.get(2, 4), "fear")
1978
1979     def do_scan(self, klass):
1980         # do a test with gaps and spans of size 1 and 2
1981         #  left=(1,11) * right=(1,11) * gapsize=(1,2)
1982         # 111, 112, 121, 122, 211, 212, 221, 222
1983         #    211
1984         #      121
1985         #         112
1986         #            212
1987         #               222
1988         #                   221
1989         #                      111
1990         #                        122
1991         #  11 1  1 11 11  11  1 1  111
1992         # 0123456789012345678901234567
1993         # abcdefghijklmnopqrstuvwxyz-=
1994         pieces = [(1, "bc"),
1995                   (4, "e"),
1996                   (7, "h"),
1997                   (9, "jk"),
1998                   (12, "mn"),
1999                   (16, "qr"),
2000                   (20, "u"),
2001                   (22, "w"),
2002                   (25, "z-="),
2003                   ]
2004         p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
2005         S = "abcdefghijklmnopqrstuvwxyz-="
2006         # TODO: when adding data, add capital letters, to make sure we aren't
2007         # just leaving the old data in place
2008         l = len(S)
2009         def base():
2010             ds = klass()
2011             for start, data in pieces:
2012                 ds.add(start, data)
2013             return ds
2014         def dump(s):
2015             p = set(s._dump())
2016             # wow, this is the first time I've ever wanted ?: in python
2017             # note: this requires python2.5
2018             d = "".join([(S[i] if i in p else " ") for i in range(l)])
2019             assert len(d) == l
2020             return d
2021         DEBUG = False
2022         for start in range(0, l):
2023             for end in range(start+1, l):
2024                 # add [start-end) to the baseline
2025                 which = "%d-%d" % (start, end-1)
2026                 p_added = set(range(start, end))
2027                 b = base()
2028                 if DEBUG:
2029                     print
2030                     print dump(b), which
2031                     add = klass(); add.add(start, S[start:end])
2032                     print dump(add)
2033                 b.add(start, S[start:end])
2034                 if DEBUG:
2035                     print dump(b)
2036                 # check that the new span is there
2037                 d = b.get(start, end-start)
2038                 self.failUnlessEqual(d, S[start:end], which)
2039                 # check that all the original pieces are still there
2040                 for t_start, t_data in pieces:
2041                     t_len = len(t_data)
2042                     self.failUnlessEqual(b.get(t_start, t_len),
2043                                          S[t_start:t_start+t_len],
2044                                          "%s %d+%d" % (which, t_start, t_len))
2045                 # check that a lot of subspans are mostly correct
2046                 for t_start in range(l):
2047                     for t_len in range(1,4):
2048                         d = b.get(t_start, t_len)
2049                         if d is not None:
2050                             which2 = "%s+(%d-%d)" % (which, t_start,
2051                                                      t_start+t_len-1)
2052                             self.failUnlessEqual(d, S[t_start:t_start+t_len],
2053                                                  which2)
2054                         # check that removing a subspan gives the right value
2055                         b2 = klass(b)
2056                         b2.remove(t_start, t_len)
2057                         removed = set(range(t_start, t_start+t_len))
2058                         for i in range(l):
2059                             exp = (((i in p_elements) or (i in p_added))
2060                                    and (i not in removed))
2061                             which2 = "%s-(%d-%d)" % (which, t_start,
2062                                                      t_start+t_len-1)
2063                             self.failUnlessEqual(bool(b2.get(i, 1)), exp,
2064                                                  which2+" %d" % i)
2065
2066     def test_test(self):
2067         self.do_basic(SimpleDataSpans)
2068         self.do_scan(SimpleDataSpans)
2069
2070     def test_basic(self):
2071         self.do_basic(DataSpans)
2072         self.do_scan(DataSpans)
2073
2074     def test_random(self):
2075         # attempt to increase coverage of corner cases by comparing behavior
2076         # of a simple-but-slow model implementation against the
2077         # complex-but-fast actual implementation, in a large number of random
2078         # operations
2079         S1 = SimpleDataSpans
2080         S2 = DataSpans
2081         s1 = S1(); s2 = S2()
2082         seed = ""
2083         def _randstr(length, seed):
2084             created = 0
2085             pieces = []
2086             while created < length:
2087                 piece = md5(seed + str(created)).hexdigest()
2088                 pieces.append(piece)
2089                 created += len(piece)
2090             return "".join(pieces)[:length]
2091         def _create(subseed):
2092             ns1 = S1(); ns2 = S2()
2093             for i in range(10):
2094                 what = md5(subseed+str(i)).hexdigest()
2095                 start = int(what[2:4], 16)
2096                 length = max(1,int(what[5:6], 16))
2097                 ns1.add(start, _randstr(length, what[7:9]));
2098                 ns2.add(start, _randstr(length, what[7:9]))
2099             return ns1, ns2
2100
2101         #print
2102         for i in range(1000):
2103             what = md5(seed+str(i)).hexdigest()
2104             op = what[0]
2105             subop = what[1]
2106             start = int(what[2:4], 16)
2107             length = max(1,int(what[5:6], 16))
2108             #print what
2109             if op in "0":
2110                 if subop in "0123456":
2111                     s1 = S1(); s2 = S2()
2112                 else:
2113                     s1, s2 = _create(what[7:11])
2114                 #print "s2 = %s" % list(s2._dump())
2115             elif op in "123456":
2116                 #print "s2.add(%d,%d)" % (start, length)
2117                 s1.add(start, _randstr(length, what[7:9]));
2118                 s2.add(start, _randstr(length, what[7:9]))
2119             elif op in "789abc":
2120                 #print "s2.remove(%d,%d)" % (start, length)
2121                 s1.remove(start, length); s2.remove(start, length)
2122             else:
2123                 #print "s2.pop(%d,%d)" % (start, length)
2124                 d1 = s1.pop(start, length); d2 = s2.pop(start, length)
2125                 self.failUnlessEqual(d1, d2)
2126             #print "s1 now %s" % list(s1._dump())
2127             #print "s2 now %s" % list(s2._dump())
2128             self.failUnlessEqual(len(s1), len(s2))
2129             self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
2130             for j in range(100):
2131                 what = md5(what[12:14]+str(j)).hexdigest()
2132                 start = int(what[2:4], 16)
2133                 length = max(1, int(what[5:6], 16))
2134                 d1 = s1.get(start, length); d2 = s2.get(start, length)
2135                 self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))